//go:build !js package bind import ( "net" "net/netip" "sync" "testing" "time" wgConn "golang.zx2c4.com/wireguard/conn" ) // TestSharedBindCreation tests basic creation and initialization func TestSharedBindCreation(t *testing.T) { // Create a UDP connection udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } defer udpConn.Close() // Create SharedBind bind, err := New(udpConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } if bind == nil { t.Fatal("SharedBind is nil") } // Verify initial reference count if bind.refCount.Load() != 1 { t.Errorf("Expected initial refCount to be 1, got %d", bind.refCount.Load()) } // Clean up if err := bind.Close(); err != nil { t.Errorf("Failed to close SharedBind: %v", err) } } // TestSharedBindReferenceCount tests reference counting func TestSharedBindReferenceCount(t *testing.T) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } bind, err := New(udpConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } // Add references bind.AddRef() if bind.refCount.Load() != 2 { t.Errorf("Expected refCount to be 2, got %d", bind.refCount.Load()) } bind.AddRef() if bind.refCount.Load() != 3 { t.Errorf("Expected refCount to be 3, got %d", bind.refCount.Load()) } // Release references bind.Release() if bind.refCount.Load() != 2 { t.Errorf("Expected refCount to be 2 after release, got %d", bind.refCount.Load()) } bind.Release() bind.Release() // This should close the connection if !bind.closed.Load() { t.Error("Expected bind to be closed after all references released") } } // TestSharedBindWriteToUDP tests the WriteToUDP functionality func TestSharedBindWriteToUDP(t *testing.T) { // Create sender senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create sender UDP connection: %v", err) } senderBind, err := New(senderConn) if err != nil { t.Fatalf("Failed to create sender SharedBind: %v", err) } defer senderBind.Close() // Create receiver receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } defer receiverConn.Close() receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) // Send data testData := []byte("Hello, SharedBind!") n, err := senderBind.WriteToUDP(testData, receiverAddr) if err != nil { t.Fatalf("WriteToUDP failed: %v", err) } if n != len(testData) { t.Errorf("Expected to send %d bytes, sent %d", len(testData), n) } // Receive data buf := make([]byte, 1024) receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) n, _, err = receiverConn.ReadFromUDP(buf) if err != nil { t.Fatalf("Failed to receive data: %v", err) } if string(buf[:n]) != string(testData) { t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) } } // TestSharedBindConcurrentWrites tests thread-safety func TestSharedBindConcurrentWrites(t *testing.T) { // Create sender senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create sender UDP connection: %v", err) } senderBind, err := New(senderConn) if err != nil { t.Fatalf("Failed to create sender SharedBind: %v", err) } defer senderBind.Close() // Create receiver receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } defer receiverConn.Close() receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) // Launch concurrent writes numGoroutines := 100 var wg sync.WaitGroup wg.Add(numGoroutines) for i := 0; i < numGoroutines; i++ { go func(id int) { defer wg.Done() data := []byte{byte(id)} _, err := senderBind.WriteToUDP(data, receiverAddr) if err != nil { t.Errorf("WriteToUDP failed in goroutine %d: %v", id, err) } }(i) } wg.Wait() } // TestSharedBindWireGuardInterface tests WireGuard Bind interface implementation func TestSharedBindWireGuardInterface(t *testing.T) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } bind, err := New(udpConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } defer bind.Close() // Test Open recvFuncs, port, err := bind.Open(0) if err != nil { t.Fatalf("Open failed: %v", err) } if len(recvFuncs) == 0 { t.Error("Expected at least one receive function") } if port == 0 { t.Error("Expected non-zero port") } // Test SetMark (should be a no-op) if err := bind.SetMark(0); err != nil { t.Errorf("SetMark failed: %v", err) } // Test BatchSize batchSize := bind.BatchSize() if batchSize <= 0 { t.Error("Expected positive batch size") } } // TestSharedBindSend tests the Send method with WireGuard endpoints func TestSharedBindSend(t *testing.T) { // Create sender senderConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create sender UDP connection: %v", err) } senderBind, err := New(senderConn) if err != nil { t.Fatalf("Failed to create sender SharedBind: %v", err) } defer senderBind.Close() // Create receiver receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } defer receiverConn.Close() receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) // Create an endpoint addrPort := receiverAddr.AddrPort() endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} // Send data testData := []byte("WireGuard packet") bufs := [][]byte{testData} err = senderBind.Send(bufs, endpoint) if err != nil { t.Fatalf("Send failed: %v", err) } // Receive data buf := make([]byte, 1024) receiverConn.SetReadDeadline(time.Now().Add(2 * time.Second)) n, _, err := receiverConn.ReadFromUDP(buf) if err != nil { t.Fatalf("Failed to receive data: %v", err) } if string(buf[:n]) != string(testData) { t.Errorf("Expected to receive %q, got %q", testData, buf[:n]) } } // TestSharedBindMultipleUsers simulates WireGuard and hole punch using the same bind func TestSharedBindMultipleUsers(t *testing.T) { // Create shared bind udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } sharedBind, err := New(udpConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } // Add reference for hole punch sender sharedBind.AddRef() // Create receiver receiverConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create receiver UDP connection: %v", err) } defer receiverConn.Close() receiverAddr := receiverConn.LocalAddr().(*net.UDPAddr) var wg sync.WaitGroup // Simulate WireGuard using the bind wg.Add(1) go func() { defer wg.Done() addrPort := receiverAddr.AddrPort() endpoint := &wgConn.StdNetEndpoint{AddrPort: addrPort} for i := 0; i < 10; i++ { data := []byte("WireGuard packet") bufs := [][]byte{data} if err := sharedBind.Send(bufs, endpoint); err != nil { t.Errorf("WireGuard Send failed: %v", err) } time.Sleep(10 * time.Millisecond) } }() // Simulate hole punch sender using the bind wg.Add(1) go func() { defer wg.Done() for i := 0; i < 10; i++ { data := []byte("Hole punch packet") if _, err := sharedBind.WriteToUDP(data, receiverAddr); err != nil { t.Errorf("Hole punch WriteToUDP failed: %v", err) } time.Sleep(10 * time.Millisecond) } }() wg.Wait() // Release the hole punch reference sharedBind.Release() // Close WireGuard's reference (should close the connection) sharedBind.Close() if !sharedBind.closed.Load() { t.Error("Expected bind to be closed after all users released it") } } // TestEndpoint tests the Endpoint implementation func TestEndpoint(t *testing.T) { addr := netip.MustParseAddr("192.168.1.1") addrPort := netip.AddrPortFrom(addr, 51820) ep := &Endpoint{AddrPort: addrPort} // Test DstIP if ep.DstIP() != addr { t.Errorf("Expected DstIP to be %v, got %v", addr, ep.DstIP()) } // Test DstToString expected := "192.168.1.1:51820" if ep.DstToString() != expected { t.Errorf("Expected DstToString to be %q, got %q", expected, ep.DstToString()) } // Test DstToBytes bytes := ep.DstToBytes() if len(bytes) == 0 { t.Error("Expected DstToBytes to return non-empty slice") } // Test SrcIP (should be zero) if ep.SrcIP().IsValid() { t.Error("Expected SrcIP to be invalid") } // Test ClearSrc (should not panic) ep.ClearSrc() } // TestParseEndpoint tests the ParseEndpoint method func TestParseEndpoint(t *testing.T) { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create UDP connection: %v", err) } bind, err := New(udpConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } defer bind.Close() tests := []struct { name string input string wantErr bool checkAddr func(*testing.T, wgConn.Endpoint) }{ { name: "valid IPv4", input: "192.168.1.1:51820", wantErr: false, checkAddr: func(t *testing.T, ep wgConn.Endpoint) { if ep.DstToString() != "192.168.1.1:51820" { t.Errorf("Expected 192.168.1.1:51820, got %s", ep.DstToString()) } }, }, { name: "valid IPv6", input: "[::1]:51820", wantErr: false, checkAddr: func(t *testing.T, ep wgConn.Endpoint) { if ep.DstToString() != "[::1]:51820" { t.Errorf("Expected [::1]:51820, got %s", ep.DstToString()) } }, }, { name: "invalid - missing port", input: "192.168.1.1", wantErr: true, }, { name: "invalid - bad format", input: "not-an-address", wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ep, err := bind.ParseEndpoint(tt.input) if (err != nil) != tt.wantErr { t.Errorf("ParseEndpoint() error = %v, wantErr %v", err, tt.wantErr) return } if !tt.wantErr && tt.checkAddr != nil { tt.checkAddr(t, ep) } }) } } // TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack func TestNetstackRouting(t *testing.T) { // Create the SharedBind with a physical UDP socket physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create physical UDP connection: %v", err) } sharedBind, err := New(physicalConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } defer sharedBind.Close() // Create a mock "netstack" connection (just another UDP socket for testing) netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create netstack UDP connection: %v", err) } defer netstackConn.Close() // Set the netstack connection sharedBind.SetNetstackConn(netstackConn) // Create a "client" that would receive packets clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create client UDP connection: %v", err) } defer clientConn.Close() clientAddr := clientConn.LocalAddr().(*net.UDPAddr) clientAddrPort := clientAddr.AddrPort() // Inject a packet from the "netstack" source - this should track the endpoint testData := []byte("test packet from netstack") err = sharedBind.InjectPacket(testData, clientAddrPort) if err != nil { t.Fatalf("InjectPacket failed: %v", err) } // Now when we send a response to this endpoint, it should go through netstack endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort} responseData := []byte("response packet") err = sharedBind.Send([][]byte{responseData}, endpoint) if err != nil { t.Fatalf("Send failed: %v", err) } // The packet should be received by the client from the netstack connection buf := make([]byte, 1024) clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) n, fromAddr, err := clientConn.ReadFromUDP(buf) if err != nil { t.Fatalf("Failed to receive response: %v", err) } if string(buf[:n]) != string(responseData) { t.Errorf("Expected to receive %q, got %q", responseData, buf[:n]) } // Verify the response came from the netstack connection, not the physical one netstackAddr := netstackConn.LocalAddr().(*net.UDPAddr) if fromAddr.Port != netstackAddr.Port { t.Errorf("Expected response from netstack port %d, got %d", netstackAddr.Port, fromAddr.Port) } } // TestSocketRouting tests that packets from socket endpoints are routed through socket func TestSocketRouting(t *testing.T) { // Create the SharedBind with a physical UDP socket physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create physical UDP connection: %v", err) } sharedBind, err := New(physicalConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } defer sharedBind.Close() // Create a mock "netstack" connection netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create netstack UDP connection: %v", err) } defer netstackConn.Close() // Set the netstack connection sharedBind.SetNetstackConn(netstackConn) // Create a "client" that would receive packets (this simulates a hole-punched client) clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create client UDP connection: %v", err) } defer clientConn.Close() clientAddr := clientConn.LocalAddr().(*net.UDPAddr) clientAddrPort := clientAddr.AddrPort() // Don't inject from netstack - this endpoint is NOT tracked as netstack-sourced // So Send should use the physical socket endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort} responseData := []byte("response packet via socket") err = sharedBind.Send([][]byte{responseData}, endpoint) if err != nil { t.Fatalf("Send failed: %v", err) } // The packet should be received by the client from the physical connection buf := make([]byte, 1024) clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) n, fromAddr, err := clientConn.ReadFromUDP(buf) if err != nil { t.Fatalf("Failed to receive response: %v", err) } if string(buf[:n]) != string(responseData) { t.Errorf("Expected to receive %q, got %q", responseData, buf[:n]) } // Verify the response came from the physical connection, not the netstack one physicalAddr := physicalConn.LocalAddr().(*net.UDPAddr) if fromAddr.Port != physicalAddr.Port { t.Errorf("Expected response from physical port %d, got %d", physicalAddr.Port, fromAddr.Port) } } // TestClearNetstackConn tests that clearing the netstack connection works correctly func TestClearNetstackConn(t *testing.T) { physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create physical UDP connection: %v", err) } sharedBind, err := New(physicalConn) if err != nil { t.Fatalf("Failed to create SharedBind: %v", err) } defer sharedBind.Close() // Set a netstack connection netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) if err != nil { t.Fatalf("Failed to create netstack UDP connection: %v", err) } defer netstackConn.Close() sharedBind.SetNetstackConn(netstackConn) // Inject a packet to track an endpoint testAddrPort := netip.MustParseAddrPort("192.168.1.100:51820") err = sharedBind.InjectPacket([]byte("test"), testAddrPort) if err != nil { t.Fatalf("InjectPacket failed: %v", err) } // Verify the endpoint is tracked _, tracked := sharedBind.netstackEndpoints.Load(testAddrPort.String()) if !tracked { t.Error("Expected endpoint to be tracked as netstack-sourced") } // Clear the netstack connection sharedBind.ClearNetstackConn() // Verify the netstack connection is cleared if sharedBind.GetNetstackConn() != nil { t.Error("Expected netstack connection to be nil after clear") } // Verify the tracked endpoints are cleared _, stillTracked := sharedBind.netstackEndpoints.Load(testAddrPort.String()) if stillTracked { t.Error("Expected endpoint tracking to be cleared") } }