Handle IPv6 candidates in userspace bind

This commit is contained in:
Viktor Liu
2025-07-09 12:22:30 +02:00
parent 90c5244a40
commit 4cff2a09a3
8 changed files with 253 additions and 22 deletions

View File

@@ -0,0 +1,153 @@
package bind
import (
"errors"
"net"
"time"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
)
var (
errNoIPv4Conn = errors.New("no IPv4 connection available")
errNoIPv6Conn = errors.New("no IPv6 connection available")
errInvalidAddr = errors.New("invalid address type")
)
// DualStackPacketConn is a composite PacketConn that can handle both IPv4 and IPv6
type DualStackPacketConn struct {
ipv4Conn net.PacketConn
ipv6Conn net.PacketConn
}
// NewDualStackPacketConn creates a new dual-stack packet connection
func NewDualStackPacketConn(ipv4Conn, ipv6Conn net.PacketConn) *DualStackPacketConn {
return &DualStackPacketConn{
ipv4Conn: ipv4Conn,
ipv6Conn: ipv6Conn,
}
}
// ReadFrom reads from both IPv4 and IPv6 connections
func (d *DualStackPacketConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) {
// Prefer IPv4 if available
if d.ipv4Conn != nil {
return d.ipv4Conn.ReadFrom(b)
}
if d.ipv6Conn != nil {
return d.ipv6Conn.ReadFrom(b)
}
return 0, nil, net.ErrClosed
}
// WriteTo writes to the appropriate connection based on the address type
func (d *DualStackPacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return 0, &net.OpError{
Op: "write",
Net: "udp",
Addr: addr,
Err: errInvalidAddr,
}
}
if udpAddr.IP.To4() == nil {
if d.ipv6Conn != nil {
return d.ipv6Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp6",
Addr: addr,
Err: errNoIPv6Conn,
}
}
if d.ipv4Conn != nil {
return d.ipv4Conn.WriteTo(b, addr)
}
return 0, &net.OpError{
Op: "write",
Net: "udp4",
Addr: addr,
Err: errNoIPv4Conn,
}
}
// Close closes both connections
func (d *DualStackPacketConn) Close() error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.Close(); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// LocalAddr returns the local address of the IPv4 connection (for compatibility)
func (d *DualStackPacketConn) LocalAddr() net.Addr {
if d.ipv4Conn != nil {
return d.ipv4Conn.LocalAddr()
}
if d.ipv6Conn != nil {
return d.ipv6Conn.LocalAddr()
}
return nil
}
// SetDeadline sets the deadline for both connections
func (d *DualStackPacketConn) SetDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetReadDeadline sets the read deadline for both connections
func (d *DualStackPacketConn) SetReadDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetReadDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}
// SetWriteDeadline sets the write deadline for both connections
func (d *DualStackPacketConn) SetWriteDeadline(t time.Time) error {
var result *multierror.Error
if d.ipv4Conn != nil {
if err := d.ipv4Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
if d.ipv6Conn != nil {
if err := d.ipv6Conn.SetWriteDeadline(t); err != nil {
result = multierror.Append(result, err)
}
}
return nberrors.FormatErrorOrNil(result)
}

View File

@@ -2,7 +2,7 @@ package bind
import (
"encoding/binary"
"fmt"
"errors"
"net"
"net/netip"
"runtime"
@@ -26,7 +26,7 @@ type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateReceiverFn(pc wgConn.PacketReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createReceiverFn(pc, conn, rxOffload, msgPool)
}
@@ -53,6 +53,8 @@ type ICEBind struct {
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
ipv4Conn *net.UDPConn
ipv6Conn *net.UDPConn
address wgaddr.Address
activityRecorder *ActivityRecorder
}
@@ -110,11 +112,11 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
if s.udpMux != nil {
return s.udpMux, nil
}
return nil, errors.New("ICEBind has not been initialized yet")
}
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
@@ -146,14 +148,40 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (s *ICEBind) createReceiverFn(pc wgConn.PacketReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
func (s *ICEBind) createReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
log.Errorf("ICEBind: unexpected address type: %T", conn.LocalAddr())
return nil
}
isIPv6 := localAddr.IP.To4() == nil
if isIPv6 {
s.ipv6Conn = conn
} else {
s.ipv4Conn = conn
}
needsNewMux := s.udpMux == nil && (s.ipv4Conn != nil || s.ipv6Conn != nil)
needsUpgrade := s.udpMux != nil && s.ipv4Conn != nil && s.ipv6Conn != nil
if needsNewMux || needsUpgrade {
var iceMuxConn net.PacketConn
switch {
case s.ipv4Conn != nil && s.ipv6Conn != nil:
iceMuxConn = NewDualStackPacketConn(s.ipv4Conn, s.ipv6Conn)
case s.ipv4Conn != nil:
iceMuxConn = s.ipv4Conn
default:
iceMuxConn = s.ipv6Conn
}
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
UDPConn: iceMuxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
@@ -198,7 +226,7 @@ func (s *ICEBind) createReceiverFn(pc wgConn.PacketReader, conn *net.UDPConn, rx
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr, isIPv6)
if ok {
continue
}
@@ -206,7 +234,12 @@ func (s *ICEBind) createReceiverFn(pc wgConn.PacketReader, conn *net.UDPConn, rx
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
udpAddr, ok := msg.Addr.(*net.UDPAddr)
if !ok {
log.Errorf("ICEBind: unexpected address type: %T", msg.Addr)
continue
}
addrPort := udpAddr.AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
@@ -220,7 +253,7 @@ func (s *ICEBind) createReceiverFn(pc wgConn.PacketReader, conn *net.UDPConn, rx
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr, isIPv6 bool) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
@@ -232,9 +265,10 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr)
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
if s.udpMux != nil {
if err := s.udpMux.HandleSTUNMessage(msg, addr); err != nil {
log.Warnf("failed to handle STUN packet: %v", err)
}
}
buffers[i] = []byte{}

View File

@@ -342,6 +342,9 @@ func (m *UDPMuxDefault) Close() error {
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
if dualStackConn, ok := m.params.UDPConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(buf, rAddr)
}
return m.params.UDPConn.WriteTo(buf, rAddr)
}

View File

@@ -126,6 +126,11 @@ type udpConn struct {
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
// Check if this is a dual-stack connection and handle IPv6 addresses properly
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(b, addr)
}
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
@@ -141,6 +146,11 @@ func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (i
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(b, addr)
}
return u.PacketConn.WriteTo(b, addr)
}
@@ -148,6 +158,11 @@ func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
if dualStackConn, ok := u.PacketConn.(*DualStackPacketConn); ok {
return dualStackConn.WriteTo(b, addr)
}
return u.PacketConn.WriteTo(b, addr)
}