diff --git a/bind/shared_bind.go b/bind/shared_bind.go new file mode 100644 index 0000000..bff66bf --- /dev/null +++ b/bind/shared_bind.go @@ -0,0 +1,378 @@ +//go:build !js + +package bind + +import ( + "fmt" + "net" + "net/netip" + "runtime" + "sync" + "sync/atomic" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + wgConn "golang.zx2c4.com/wireguard/conn" +) + +// Endpoint represents a network endpoint for the SharedBind +type Endpoint struct { + AddrPort netip.AddrPort +} + +// ClearSrc implements the wgConn.Endpoint interface +func (e *Endpoint) ClearSrc() {} + +// DstIP implements the wgConn.Endpoint interface +func (e *Endpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +// SrcIP implements the wgConn.Endpoint interface +func (e *Endpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +// DstToBytes implements the wgConn.Endpoint interface +func (e *Endpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +// DstToString implements the wgConn.Endpoint interface +func (e *Endpoint) DstToString() string { + return e.AddrPort.String() +} + +// SrcToString implements the wgConn.Endpoint interface +func (e *Endpoint) SrcToString() string { + return "" +} + +// SharedBind is a thread-safe UDP bind that can be shared between WireGuard +// and hole punch senders. It wraps a single UDP connection and implements +// reference counting to prevent premature closure. +type SharedBind struct { + mu sync.RWMutex + + // The underlying UDP connection + udpConn *net.UDPConn + + // IPv4 and IPv6 packet connections for advanced features + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + + // Reference counting to prevent closing while in use + refCount atomic.Int32 + closed atomic.Bool + + // Channels for receiving data + recvFuncs []wgConn.ReceiveFunc + + // Port binding information + port uint16 +} + +// New creates a new SharedBind from an existing UDP connection. +// The SharedBind takes ownership of the connection and will close it +// when all references are released. +func New(udpConn *net.UDPConn) (*SharedBind, error) { + if udpConn == nil { + return nil, fmt.Errorf("udpConn cannot be nil") + } + + bind := &SharedBind{ + udpConn: udpConn, + } + + // Initialize reference count to 1 (the creator holds the first reference) + bind.refCount.Store(1) + + // Get the local port + if addr, ok := udpConn.LocalAddr().(*net.UDPAddr); ok { + bind.port = uint16(addr.Port) + } + + return bind, nil +} + +// AddRef increments the reference count. Call this when sharing +// the bind with another component. +func (b *SharedBind) AddRef() { + newCount := b.refCount.Add(1) + // Optional: Add logging for debugging + _ = newCount // Placeholder for potential logging +} + +// Release decrements the reference count. When it reaches zero, +// the underlying UDP connection is closed. +func (b *SharedBind) Release() error { + newCount := b.refCount.Add(-1) + // Optional: Add logging for debugging + _ = newCount // Placeholder for potential logging + + if newCount < 0 { + // This should never happen with proper usage + b.refCount.Store(0) + return fmt.Errorf("SharedBind reference count went negative") + } + + if newCount == 0 { + return b.closeConnection() + } + + return nil +} + +// closeConnection actually closes the UDP connection +func (b *SharedBind) closeConnection() error { + if !b.closed.CompareAndSwap(false, true) { + // Already closed + return nil + } + + b.mu.Lock() + defer b.mu.Unlock() + + var err error + if b.udpConn != nil { + err = b.udpConn.Close() + b.udpConn = nil + } + + b.ipv4PC = nil + b.ipv6PC = nil + + return err +} + +// GetUDPConn returns the underlying UDP connection. +// The caller must not close this connection directly. +func (b *SharedBind) GetUDPConn() *net.UDPConn { + b.mu.RLock() + defer b.mu.RUnlock() + return b.udpConn +} + +// GetRefCount returns the current reference count (for debugging) +func (b *SharedBind) GetRefCount() int32 { + return b.refCount.Load() +} + +// IsClosed returns whether the bind is closed +func (b *SharedBind) IsClosed() bool { + return b.closed.Load() +} + +// WriteToUDP writes data to a specific UDP address. +// This is thread-safe and can be used by hole punch senders. +func (b *SharedBind) WriteToUDP(data []byte, addr *net.UDPAddr) (int, error) { + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + return conn.WriteToUDP(data, addr) +} + +// Close implements the WireGuard Bind interface. +// It decrements the reference count and closes the connection if no references remain. +func (b *SharedBind) Close() error { + return b.Release() +} + +// Open implements the WireGuard Bind interface. +// Since the connection is already open, this just sets up the receive functions. +func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { + if b.closed.Load() { + return nil, 0, net.ErrClosed + } + + b.mu.Lock() + defer b.mu.Unlock() + + if b.udpConn == nil { + return nil, 0, net.ErrClosed + } + + // Set up IPv4 and IPv6 packet connections for advanced features + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + b.ipv4PC = ipv4.NewPacketConn(b.udpConn) + b.ipv6PC = ipv6.NewPacketConn(b.udpConn) + } + + // Create receive functions + recvFuncs := make([]wgConn.ReceiveFunc, 0, 2) + + // Add IPv4 receive function + if b.ipv4PC != nil || runtime.GOOS != "linux" { + recvFuncs = append(recvFuncs, b.makeReceiveIPv4()) + } + + // Add IPv6 receive function if needed + // For now, we focus on IPv4 for hole punching use case + + b.recvFuncs = recvFuncs + return recvFuncs, b.port, nil +} + +// makeReceiveIPv4 creates a receive function for IPv4 packets +func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + pc := b.ipv4PC + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + // Use batch reading on Linux for performance + if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + return b.receiveIPv4Batch(pc, bufs, sizes, eps) + } + + // Fallback to simple read for other platforms + return b.receiveIPv4Simple(conn, bufs, sizes, eps) + } +} + +// receiveIPv4Batch uses batch reading for better performance on Linux +func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + // Create messages for batch reading + msgs := make([]ipv4.Message, len(bufs)) + for i := range bufs { + msgs[i].Buffers = [][]byte{bufs[i]} + msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use + } + + numMsgs, err := pc.ReadBatch(msgs, 0) + if err != nil { + return 0, err + } + + for i := 0; i < numMsgs; i++ { + sizes[i] = msgs[i].N + if sizes[i] == 0 { + continue + } + + if msgs[i].Addr != nil { + if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok { + addrPort := udpAddr.AddrPort() + eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + } + } + + return numMsgs, nil +} + +// receiveIPv4Simple uses simple ReadFromUDP for non-Linux platforms +func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { + n, addr, err := conn.ReadFromUDP(bufs[0]) + if err != nil { + return 0, err + } + + sizes[0] = n + if addr != nil { + addrPort := addr.AddrPort() + eps[0] = &wgConn.StdNetEndpoint{AddrPort: addrPort} + } + + return 1, nil +} + +// Send implements the WireGuard Bind interface. +// It sends packets to the specified endpoint. +func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { + if b.closed.Load() { + return net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return net.ErrClosed + } + + // Extract the destination address from the endpoint + var destAddr *net.UDPAddr + + // Try to cast to StdNetEndpoint first + if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { + destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort) + } else { + // Fallback: construct from DstIP and DstToBytes + dstBytes := ep.DstToBytes() + if len(dstBytes) >= 6 { // Minimum for IPv4 (4 bytes) + port (2 bytes) + var addr netip.Addr + var port uint16 + + if len(dstBytes) >= 18 { // IPv6 (16 bytes) + port (2 bytes) + addr, _ = netip.AddrFromSlice(dstBytes[:16]) + port = uint16(dstBytes[16]) | uint16(dstBytes[17])<<8 + } else { // IPv4 + addr, _ = netip.AddrFromSlice(dstBytes[:4]) + port = uint16(dstBytes[4]) | uint16(dstBytes[5])<<8 + } + + if addr.IsValid() { + destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) + } + } + } + + if destAddr == nil { + return fmt.Errorf("could not extract destination address from endpoint") + } + + // Send all buffers to the destination + for _, buf := range bufs { + _, err := conn.WriteToUDP(buf, destAddr) + if err != nil { + return err + } + } + + return nil +} + +// SetMark implements the WireGuard Bind interface. +// It's a no-op for this implementation. +func (b *SharedBind) SetMark(mark uint32) error { + // Not implemented for this use case + return nil +} + +// BatchSize returns the preferred batch size for sending packets. +func (b *SharedBind) BatchSize() int { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + return wgConn.IdealBatchSize + } + return 1 +} + +// ParseEndpoint creates a new endpoint from a string address. +func (b *SharedBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { + addrPort, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return &wgConn.StdNetEndpoint{AddrPort: addrPort}, nil +} diff --git a/bind/shared_bind_test.go b/bind/shared_bind_test.go new file mode 100644 index 0000000..6e1ec66 --- /dev/null +++ b/bind/shared_bind_test.go @@ -0,0 +1,424 @@ +//go:build !js + +package bind + +import ( + "net" + "net/netip" + "sync" + "testing" + "time" + + wgConn "golang.zx2c4.com/wireguard/conn" +) + +// TestSharedBindCreation tests basic creation and initialization +func TestSharedBindCreation(t *testing.T) { + // Create a UDP connection + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + defer udpConn.Close() + + // Create SharedBind + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + if bind == nil { + t.Fatal("SharedBind is nil") + } + + // Verify initial reference count + if bind.refCount.Load() != 1 { + t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load()) + } + + // Clean up + if err := bind.Close(); err != nil { + t.Errorf("Failed to close SharedBind: %v", err) + } +} + +// TestSharedBindReferenceCount tests reference counting +func TestSharedBindReferenceCount(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + // Add references + bind.AddRef() + if bind.refCount.Load() != 2 { + t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load()) + } + + bind.AddRef() + if bind.refCount.Load() != 3 { + t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load()) + } + + // Release references + bind.Release() + if bind.refCount.Load() != 2 { + t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load()) + } + + bind.Release() + bind.Release() // This should close the connection + + if !bind.closed.Load() { + t.Error("Expected bind to be closed after all references released") + } +} + +// TestSharedBindWriteToUDP tests the WriteToUDP functionality +func TestSharedBindWriteToUDP(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Send data + testData := []byte("Hello, SharedBind!") + n, err := senderBind.WriteToUDP(testData, receiverAddr) + if err != nil { + t.Fatalf("WriteToUDP failed: %v", err) + } + + if n != len(testData) { + t.Errorf("Expected to send %d bytes, sent %d", len(testData), n) + } + + // Receive data + buf := make([]byte, 1024) + receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, _, err = receiverConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive data: %v", err) + } + + if string(buf[:n]) != string(testData) { + t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) + } +} + +// TestSharedBindConcurrentWrites tests thread-safety +func TestSharedBindConcurrentWrites(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Launch concurrent writes + numGoroutines := 100 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + data := []byte{byte(id)} + _, err := senderBind.WriteToUDP(data, receiverAddr) + if err != nil { + t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err) + } + }(i) + } + + wg.Wait() +} + +// TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation +func TestSharedBindWireGuardInterface(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer bind.Close() + + // Test Open + recvFuncs, port, err := bind.Open(0) + if err != nil { + t.Fatalf("Open failed: %v", err) + } + + if len(recvFuncs) == 0 { + t.Error("Expected at least one receive function") + } + + if port == 0 { + t.Error("Expected non-zero port") + } + + // Test SetMark (should be a no-op) + if err := bind.SetMark(0); err != nil { + t.Errorf("SetMark failed: %v", err) + } + + // Test BatchSize + batchSize := bind.BatchSize() + if batchSize <= 0 { + t.Error("Expected positive batch size") + } +} + +// TestSharedBindSend tests the Send method with WireGuard endpoints +func TestSharedBindSend(t *testing.T) { + // Create sender + senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create sender UDP connection: %v", err) + } + + senderBind, err := New(senderConn) + if err != nil { + t.Fatalf("Failed to create sender SharedBind: %v", err) + } + defer senderBind.Close() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + // Create an endpoint + addrPort := receiverAddr.AddrPort() + endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} + + // Send data + testData := []byte("WireGuard packet") + bufs := [][]byte{testData} + err = senderBind.Send(bufs, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // Receive data + buf := make([]byte, 1024) + receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, _, err := receiverConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive data: %v", err) + } + + if string(buf[:n]) != string(testData) { + t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) + } +} + +// TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind +func TestSharedBindMultipleUsers(t *testing.T) { + // Create shared bind + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + sharedBind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + + // Add reference for hole punch sender + sharedBind.AddRef() + + // Create receiver + receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create receiver UDP connection: %v", err) + } + defer receiverConn.Close() + + receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) + + var wg sync.WaitGroup + + // Simulate WireGuard using the bind + wg.Add(1) + go func() { + defer wg.Done() + addrPort := receiverAddr.AddrPort() + endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} + + for i := 0; i < 10; i++ { + data := []byte("WireGuard packet") + bufs := [][]byte{data} + if err := sharedBind.Send(bufs, endpoint); err != nil { + t.Errorf("WireGuard Send failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + }() + + // Simulate hole punch sender using the bind + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + data := []byte("Hole punch packet") + if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil { + t.Errorf("Hole punch WriteToUDP failed: %v", err) + } + time.Sleep(10 * time.Millisecond) + } + }() + + wg.Wait() + + // Release the hole punch reference + sharedBind.Release() + + // Close WireGuard's reference (should close the connection) + sharedBind.Close() + + if !sharedBind.closed.Load() { + t.Error("Expected bind to be closed after all users released it") + } +} + +// TestEndpoint tests the Endpoint implementation +func TestEndpoint(t *testing.T) { + addr := netip.MustParseAddr("192.168.1.1") + addrPort := netip.AddrPortFrom(addr, 51820) + + ep := &Endpoint{AddrPort: addrPort} + + // Test DstIP + if ep.DstIP() != addr { + t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP()) + } + + // Test DstToString + expected := "192.168.1.1:51820" + if ep.DstToString() != expected { + t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString()) + } + + // Test DstToBytes + bytes := ep.DstToBytes() + if len(bytes) == 0 { + t.Error("Expected DstToBytes to return non-empty slice") + } + + // Test SrcIP (should be zero) + if ep.SrcIP().IsValid() { + t.Error("Expected SrcIP to be invalid") + } + + // Test ClearSrc (should not panic) + ep.ClearSrc() +} + +// TestParseEndpoint tests the ParseEndpoint method +func TestParseEndpoint(t *testing.T) { + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create UDP connection: %v", err) + } + + bind, err := New(udpConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer bind.Close() + + tests := []struct { + name string + input string + wantErr bool + checkAddr func(*testing.T, wgConn.Endpoint) + }{ + { + name: "valid IPv4", + input: "192.168.1.1:51820", + wantErr: false, + checkAddr: func(t *testing.T, ep wgConn.Endpoint) { + if ep.DstToString() != "192.168.1.1:51820" { + t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString()) + } + }, + }, + { + name: "valid IPv6", + input: "[::1]:51820", + wantErr: false, + checkAddr: func(t *testing.T, ep wgConn.Endpoint) { + if ep.DstToString() != "[::1]:51820" { + t.Errorf("Expected [::1]:51820, got %s", ep.DstToString()) + } + }, + }, + { + name: "invalid - missing port", + input: "192.168.1.1", + wantErr: true, + }, + { + name: "invalid - bad format", + input: "not-an-address", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep, err := bind.ParseEndpoint(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && tt.checkAddr != nil { + tt.checkAddr(t, ep) + } + }) + } +} diff --git a/olm-binary.REMOVED.git-id b/olm-binary.REMOVED.git-id new file mode 100644 index 0000000..78de5d4 --- /dev/null +++ b/olm-binary.REMOVED.git-id @@ -0,0 +1 @@ +767662d6fa777b3bb77d47a1c44eb5fb60249e87 \ No newline at end of file diff --git a/olm-test.REMOVED.git-id b/olm-test.REMOVED.git-id new file mode 100644 index 0000000..60202ca --- /dev/null +++ b/olm-test.REMOVED.git-id @@ -0,0 +1 @@ +ba2c118fd96937229ef54dcd0b82fe5d53d94a87 \ No newline at end of file diff --git a/olm/common.go b/olm/common.go index 7da0aa9..f082a6a 100644 --- a/olm/common.go +++ b/olm/common.go @@ -14,13 +14,13 @@ import ( "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/bind" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "github.com/vishvananda/netlink" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" "golang.org/x/exp/rand" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -82,11 +82,6 @@ const ( ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" ) -type fixedPortBind struct { - port uint16 - conn.Bind -} - // PeerAction represents a request to add, update, or remove a peer type PeerAction struct { Action string `json:"action"` // "add", "update", or "remove" @@ -124,11 +119,6 @@ type RelayPeerData struct { PublicKey string `json:"publicKey"` } -func (b *fixedPortBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) { - // Ignore the requested port and use our fixed port - return b.Bind.Open(b.port) -} - // Helper function to format endpoints correctly func formatEndpoint(endpoint string) string { if endpoint == "" { @@ -156,13 +146,6 @@ func formatEndpoint(endpoint string) string { return endpoint } -func NewFixedPortBind(port uint16) conn.Bind { - return &fixedPortBind{ - port: port, - Bind: conn.NewDefaultBind(), - } -} - func fixKey(key string) string { // Remove any whitespace key = strings.TrimSpace(key) @@ -523,6 +506,196 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, s } } +// keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind sends hole punch packets using the shared bind +func keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(exitNodes []ExitNode, olmID string, sharedBind *bind.SharedBind) { + if len(exitNodes) == 0 { + logger.Warn("No exit nodes provided for hole punching") + return + } + + // Check if hole punching is already running + if holePunchRunning { + logger.Debug("UDP hole punch already running, skipping new request") + return + } + + // Set the flag to indicate hole punching is running + holePunchRunning = true + defer func() { + holePunchRunning = false + logger.Info("UDP hole punch goroutine ended") + }() + + logger.Info("Starting UDP hole punch to %d exit nodes with shared bind", len(exitNodes)) + defer logger.Info("UDP hole punch goroutine ended for all exit nodes") + + // Resolve all endpoints upfront + type resolvedExitNode struct { + remoteAddr *net.UDPAddr + publicKey string + endpointName string + } + + var resolvedNodes []resolvedExitNode + for _, exitNode := range exitNodes { + host, err := resolveDomain(exitNode.Endpoint) + if err != nil { + logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := net.JoinHostPort(host, "21820") + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err) + continue + } + + resolvedNodes = append(resolvedNodes, resolvedExitNode{ + remoteAddr: remoteAddr, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, + }) + logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + } + + if len(resolvedNodes) == 0 { + logger.Error("No exit nodes could be resolved") + return + } + + // Send initial hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil { + logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err) + } + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + + for { + select { + case <-stopHolepunch: + logger.Info("Stopping UDP holepunch for all exit nodes") + return + case <-timeout.C: + logger.Info("UDP holepunch routine timed out after 15 seconds for all exit nodes") + return + case <-ticker.C: + // Send hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := sendUDPHolePunchWithBind(sharedBind, node.remoteAddr, olmID, node.publicKey); err != nil { + logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err) + } + } + } + } +} + +// keepSendingUDPHolePunchWithSharedBind sends hole punch packets to a single endpoint using shared bind +func keepSendingUDPHolePunchWithSharedBind(endpoint string, olmID string, sharedBind *bind.SharedBind, serverPubKey string) { + // Check if hole punching is already running + if holePunchRunning { + logger.Debug("UDP hole punch already running, skipping new request") + return + } + + // Set the flag to indicate hole punching is running + holePunchRunning = true + defer func() { + holePunchRunning = false + logger.Info("UDP hole punch goroutine ended") + }() + + logger.Info("Starting UDP hole punch to %s with shared bind", endpoint) + defer logger.Info("UDP hole punch goroutine ended for %s", endpoint) + + host, err := resolveDomain(endpoint) + if err != nil { + logger.Error("Failed to resolve domain %s: %v", endpoint, err) + return + } + + serverAddr := net.JoinHostPort(host, "21820") + + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address %s: %v", serverAddr, err) + return + } + + // Execute once immediately before starting the loop + if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil { + logger.Error("Failed to send initial UDP hole punch: %v", err) + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + + for { + select { + case <-stopHolepunch: + logger.Info("Stopping UDP holepunch") + return + case <-timeout.C: + logger.Info("UDP holepunch routine timed out after 15 seconds") + return + case <-ticker.C: + if err := sendUDPHolePunchWithBind(sharedBind, remoteAddr, olmID, serverPubKey); err != nil { + logger.Error("Failed to send UDP hole punch: %v", err) + } + } + } +} + +// sendUDPHolePunchWithBind sends an encrypted hole punch packet using the shared bind +func sendUDPHolePunchWithBind(sharedBind *bind.SharedBind, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error { + if serverPubKey == "" || olmToken == "" { + return fmt.Errorf("server public key or OLM token is empty") + } + + payload := struct { + OlmID string `json:"olmId"` + Token string `json:"token"` + }{ + OlmID: olmID, + Token: olmToken, + } + + // Convert payload to JSON + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + // Encrypt the payload using the server's WireGuard public key + encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) + if err != nil { + return fmt.Errorf("failed to encrypt payload: %w", err) + } + + jsonData, err := json.Marshal(encryptedPayload) + if err != nil { + return fmt.Errorf("failed to marshal encrypted payload: %w", err) + } + + _, err = sharedBind.WriteToUDP(jsonData, remoteAddr) + if err != nil { + return fmt.Errorf("failed to write to UDP: %w", err) + } + + logger.Debug("Sent UDP hole punch to %s: %s", remoteAddr.String(), string(jsonData)) + + return nil +} + func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { if maxPort < minPort { return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) diff --git a/olm/olm.go b/olm/olm.go index 895acd9..7821a32 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -12,6 +12,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" "github.com/fosrl/olm/api" + "github.com/fosrl/olm/bind" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" @@ -67,6 +68,7 @@ var ( olmClient *websocket.Client tunnelCancel context.CancelFunc tunnelRunning bool + sharedBind *bind.SharedBind ) func Run(ctx context.Context, config Config) { @@ -226,10 +228,36 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - sourcePort, err := FindAvailableUDPPort(49152, 65535) - if err != nil { - logger.Error("Error finding available port: %v", err) - return + // Create shared UDP socket for both holepunch and WireGuard + if sharedBind == nil { + sourcePort, err := FindAvailableUDPPort(49152, 65535) + if err != nil { + logger.Error("Error finding available port: %v", err) + return + } + + localAddr := &net.UDPAddr{ + Port: int(sourcePort), + IP: net.IPv4zero, + } + + udpConn, err := net.ListenUDP("udp", localAddr) + if err != nil { + logger.Error("Failed to create shared UDP socket: %v", err) + return + } + + sharedBind, err = bind.New(udpConn) + if err != nil { + logger.Error("Failed to create shared bind: %v", err) + udpConn.Close() + return + } + + // Add a reference for the hole punch senders (creator already has one reference for WireGuard) + sharedBind.AddRef() + + logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) } olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { @@ -251,7 +279,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, // Start a single hole punch goroutine for all exit nodes logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes)) - go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort) + go keepSendingUDPHolePunchToMultipleExitNodesWithSharedBind(holePunchData.ExitNodes, id, sharedBind) }) olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { @@ -289,7 +317,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, // Start hole punching for each exit node logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) - go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey) + go keepSendingUDPHolePunchWithSharedBind(legacyHolePunchData.Endpoint, id, sharedBind, legacyHolePunchData.ServerPubKey) }) olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { @@ -305,7 +333,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, stopRegister = nil } - close(stopHolepunch) + // close(stopHolepunch) // wait 10 milliseconds to ensure the previous connection is closed logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed") @@ -367,7 +395,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) + dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) uapiListener, err = uapiListen(interfaceName, fileUAPI) if err != nil { @@ -804,7 +832,7 @@ func Stop() { uapiListener = nil } if dev != nil { - dev.Close() + dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference dev = nil } // Close TUN device @@ -813,6 +841,15 @@ func Stop() { tdev = nil } + // Release the hole punch reference to the shared bind + if sharedBind != nil { + // Release hole punch reference (WireGuard already released its reference via dev.Close()) + logger.Debug("Releasing shared bind (refcount before release: %d)", sharedBind.GetRefCount()) + sharedBind.Release() + sharedBind = nil + logger.Info("Released shared UDP bind") + } + logger.Info("Olm service stopped") }