Files
netbird/client/iface/wgproxy/bind/proxy.go
Zoltan Papp 05af39a69b [client] Add IPv6 support to UDP WireGuard proxy (#5169)
* Add IPv6 support to UDP WireGuard proxy

Add IPv6 packet header support in UDP raw socket proxy
to handle both IPv4 and IPv6 source addresses.
Refactor error handling in proxy bind implementations
to validate endpoints before acquiring locks.
2026-01-26 14:03:32 +01:00

228 lines
5.1 KiB
Go

package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
// Bind is the subset of the WireGuard bind that ProxyBind depends on.
type Bind interface {
// SetEndpoint is called by ProxyBind.Work with the fake endpoint address and
// the remote connection that belongs to it.
SetEndpoint(addr netip.Addr, conn net.Conn)
// RemoveEndpoint is called by ProxyBind when the proxy is closed, with the
// same fake address that was passed to SetEndpoint.
RemoveEndpoint(addr netip.Addr)
// ReceiveFromEndpoint is called for every packet read from the remote
// connection; ep is the endpoint the packet is attributed to and buf holds
// exactly the bytes that were read.
ReceiveFromEndpoint(ctx context.Context, ep *bind.Endpoint, buf []byte)
}
// ProxyBind reads packets from a remote connection and hands them to a Bind,
// attributing them to a fake (or redirected) endpoint. Forwarding can be
// paused, resumed, redirected and closed.
type ProxyBind struct {
// bind is notified about the endpoint and receives the packets read from remoteConn.
bind Bind
// wgRelayedEndpoint is the fake endpoint derived from the remote NetBird peer
// address by fakeAddress in AddTurnConn.
wgRelayedEndpoint *bind.Endpoint
// wgCurrentUsed is the endpoint passed to Bind.ReceiveFromEndpoint; Work sets it
// to wgRelayedEndpoint and RedirectAs replaces it. Written under pausedCond.L.
wgCurrentUsed *bind.Endpoint
// remoteConn is the connection packets are read from; nil until AddTurnConn is called.
remoteConn net.Conn
// ctx and cancel are derived from the context passed to AddTurnConn; cancel is
// invoked when the proxy is closed.
ctx context.Context
cancel context.CancelFunc
// closeMu guards closed so the shutdown sequence runs only once.
closeMu sync.Mutex
closed bool
// paused is guarded by pausedCond.L; the read loop waits on pausedCond while it is true.
paused bool
pausedCond *sync.Cond
// isStarted records that the read loop goroutine was launched; set under pausedCond.L.
isStarted bool
// closeListener is notified when reading from remoteConn fails while ctx is still active.
closeListener *listener.CloseListener
// mtu is the read buffer size: the configured MTU plus bufsize.WGBufferOverhead.
mtu uint16
}
// NewProxyBind creates a ProxyBind that delivers packets to the given bind.
// The read buffer size is the supplied mtu plus bufsize.WGBufferOverhead.
// The returned proxy has no remote connection until AddTurnConn is called.
func NewProxyBind(bind Bind, mtu uint16) *ProxyBind {
	bufferSize := mtu + bufsize.WGBufferOverhead
	pauseLock := &sync.Mutex{}

	return &ProxyBind{
		bind:          bind,
		closeListener: listener.NewCloseListener(),
		pausedCond:    sync.NewCond(pauseLock),
		mtu:           bufferSize,
	}
}
// AddTurnConn attaches the remote connection to the proxy and derives the fake
// endpoint that identifies the remote peer.
//
// Parameters:
//   - ctx: parent context; the proxy derives a cancellable context from it that is
//     passed to the read loop and cancelled when the proxy is closed
//   - nbAddr: the NetBird UDP address of the remote peer, used to generate the fake address
//   - remoteConn: the established TURN connection to the remote peer
//
// It returns the error from fakeAddress when nbAddr cannot be converted; in that
// case the proxy state is left unchanged.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
	fakeAddrPort, err := fakeAddress(nbAddr)
	if err != nil {
		return err
	}

	proxyCtx, cancel := context.WithCancel(ctx)

	p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeAddrPort}
	p.remoteConn = remoteConn
	p.ctx = proxyCtx
	p.cancel = cancel
	return nil
}
// EndpointAddr returns the fake relayed endpoint converted to a *net.UDPAddr.
// It dereferences wgRelayedEndpoint, so AddTurnConn must have succeeded first.
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
	relayed := *p.wgRelayedEndpoint
	return bind.EndpointToUDPAddr(relayed)
}
// SetDisconnectListener registers disconnected on the proxy's close listener.
// The listener is notified when reading from the remote connection fails while
// the proxy context is still active, and it is reset to nil when the proxy closes.
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work registers the relayed endpoint with the bind and (re)activates packet
// forwarding towards it. The read loop goroutine is launched on the first call
// only; later calls just resume a paused loop. It is a no-op when no remote
// connection has been attached.
func (p *ProxyBind) Work() {
	if p.remoteConn == nil {
		return
	}

	p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)

	p.pausedCond.L.Lock()
	defer p.pausedCond.L.Unlock()

	p.paused = false
	p.wgCurrentUsed = p.wgRelayedEndpoint

	// Start the proxy only once
	if !p.isStarted {
		p.isStarted = true
		go p.proxyToLocal(p.ctx)
	}

	// Wake the read loop in case it is waiting in the paused state.
	p.pausedCond.Signal()
}
// Pause marks the proxy as paused so the read loop stops delivering packets to
// the bind until Work or RedirectAs is called. It is a no-op when no remote
// connection has been attached.
func (p *ProxyBind) Pause() {
	if p.remoteConn == nil {
		return
	}

	p.pausedCond.L.Lock()
	defer p.pausedCond.L.Unlock()

	p.paused = true
}
// RedirectAs resumes forwarding and attributes subsequent packets to the given
// endpoint instead of the relayed one. The endpoint is validated before the
// lock is taken; on a conversion failure the error is logged and the proxy
// state is left unchanged.
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
	target, err := addrToEndpoint(endpoint)
	if err != nil {
		log.Errorf("failed to start package redirection: %v", err)
		return
	}

	p.pausedCond.L.Lock()
	defer p.pausedCond.L.Unlock()

	p.wgCurrentUsed = target
	p.paused = false

	// Wake the read loop in case it is waiting in the paused state.
	p.pausedCond.Signal()
}
// CloseConn shuts the proxy down. It returns an error when no cancel function
// has been set, i.e. AddTurnConn was never called successfully; otherwise it
// returns the result of the internal close sequence.
func (p *ProxyBind) CloseConn() error {
	if p.cancel != nil {
		return p.close()
	}
	return fmt.Errorf("proxy not started")
}
// close runs the shutdown sequence at most once: it clears the close listener,
// cancels the proxy context, releases a paused read loop, removes the endpoint
// from the bind and closes the remote connection. Repeated calls, and calls made
// before a remote connection was attached, return nil. An "already closed" error
// from the remote connection is treated as success.
func (p *ProxyBind) close() error {
	if p.remoteConn == nil {
		return nil
	}

	p.closeMu.Lock()
	defer p.closeMu.Unlock()

	if p.closed {
		return nil
	}

	p.closeListener.SetCloseListener(nil)
	p.closed = true
	p.cancel()

	// Release the read loop if it is waiting in the paused state so it can exit.
	p.pausedCond.L.Lock()
	p.paused = false
	p.pausedCond.Signal()
	p.pausedCond.L.Unlock()

	p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())

	closeErr := p.remoteConn.Close()
	if closeErr == nil || errors.Is(closeErr, net.ErrClosed) {
		return nil
	}
	return closeErr
}
// proxyToLocal is the read loop: it reads packets from the remote connection and
// hands each one to the bind, attributed to the currently used endpoint. While
// the proxy is paused, the packet that was just read is held until Work,
// RedirectAs or close clears the paused flag. The loop ends on the first read
// error and always runs the close sequence on exit. The close listener is
// notified, and the error logged, only when ctx has not been cancelled.
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
	defer func() {
		closeErr := p.close()
		if closeErr != nil {
			log.Warnf("failed to close remote conn: %s", closeErr)
		}
	}()

	for {
		// A fresh buffer per packet: the slice is handed to the bind, which
		// presumably may keep it after the call returns — verify before reusing.
		packet := make([]byte, p.mtu)

		size, readErr := p.remoteConn.Read(packet)
		if readErr != nil {
			if ctx.Err() == nil {
				p.closeListener.Notify()
				log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), readErr)
			}
			return
		}

		p.pausedCond.L.Lock()
		for p.paused {
			p.pausedCond.Wait()
		}
		// Delivered under the lock so wgCurrentUsed cannot change mid-call.
		p.bind.ReceiveFromEndpoint(ctx, p.wgCurrentUsed, packet[:size])
		p.pausedCond.L.Unlock()
	}
}
// fakeAddress returns a fake address that is used as an identifier for the peer.
// The fake address has the form 127.1.x.x, where x.x are the last two octets of
// the peer's IPv4 address; the port is taken over from the peer address.
//
// It returns an error when peerAddress is nil or when its IP does not render as
// a dotted-quad IPv4 address (for example an IPv6 address or an unset IP).
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
	// Reject nil up front, as addrToEndpoint does, instead of panicking on the
	// field access below.
	if peerAddress == nil {
		return nil, fmt.Errorf("invalid address")
	}

	octets := strings.Split(peerAddress.IP.String(), ".")
	if len(octets) != 4 {
		return nil, fmt.Errorf("invalid IP format")
	}

	fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
	if err != nil {
		return nil, fmt.Errorf("parse new IP: %w", err)
	}

	netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
	return &netipAddr, nil
}
// addrToEndpoint converts a UDP address into a bind.Endpoint. An IPv4-mapped
// IPv6 address is unmapped to plain IPv4. It returns an error when addr is nil
// or when its IP is not a valid 4- or 16-byte address.
func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) {
	if addr == nil {
		return nil, fmt.Errorf("invalid address")
	}

	parsed, valid := netip.AddrFromSlice(addr.IP)
	if !valid {
		return nil, fmt.Errorf("convert %s to netip.Addr", addr)
	}

	endpoint := &bind.Endpoint{
		AddrPort: netip.AddrPortFrom(parsed.Unmap(), uint16(addr.Port)),
	}
	return endpoint, nil
}