diff --git a/iface/bind/bind.go b/iface/bind/bind.go index 100449d31..84a1660f3 100644 --- a/iface/bind/bind.go +++ b/iface/bind/bind.go @@ -1,189 +1,321 @@ package bind import ( + "context" "errors" "fmt" "net" "net/netip" + "strconv" "sync" "syscall" - "github.com/pion/stun" "github.com/pion/transport/v2" - log "github.com/sirupsen/logrus" - "golang.zx2c4.com/wireguard/conn" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + wgConn "golang.zx2c4.com/wireguard/conn" ) -// ICEBind is the userspace implementation of WireGuard's conn.Bind interface using ice.UDPMux of the pion/ice library -type ICEBind struct { - // below fields, initialized on open - ipv4 net.PacketConn - udpMux *UniversalUDPMuxDefault +var ( + _ wgConn.Bind = (*ICEBind)(nil) +) - // below are fields initialized on creation +// ICEBind implements Bind for all platforms except Windows. +type ICEBind struct { + mu sync.Mutex // protects following fields + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool + ipv4PC *ipv4.PacketConn + ipv6PC *ipv6.PacketConn + batchSize int + udpAddrPool sync.Pool + ipv4MsgsPool sync.Pool + ipv6MsgsPool sync.Pool + + // NetBird related variables transportNet transport.Net - mu sync.Mutex + udpMux *UniversalUDPMuxDefault } -// NewICEBind create a new instance of ICEBind with a given transportNet function. -// The transportNet can be nil. func NewICEBind(transportNet transport.Net) *ICEBind { return &ICEBind{ - transportNet: transportNet, - mu: sync.Mutex{}, - } -} + batchSize: wgConn.DefaultBatchSize, -// GetICEMux returns the ICE UDPMux that was created and used by ICEBind -func (b *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { - b.mu.Lock() - defer b.mu.Unlock() - if b.udpMux == nil { - return nil, fmt.Errorf("ICEBind has not been initialized yet") - } - - return b.udpMux, nil -} - -// Open creates a WireGuard socket and an instance of UDPMux that is used to glue up ICE and WireGuard for hole punching -func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { - b.mu.Lock() - defer b.mu.Unlock() - - if b.ipv4 != nil { - return nil, 0, conn.ErrBindAlreadyOpen - } - - var err error - b.ipv4, _, err = listenNet("udp4", int(uport)) - if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, 0, err - } - - b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.ipv4, Net: b.transportNet}) - - portAddr, err := netip.ParseAddrPort(b.ipv4.LocalAddr().String()) - if err != nil { - return nil, 0, err - } - - log.Infof("opened ICEBind on %s", b.ipv4.LocalAddr().String()) - - return []conn.ReceiveFunc{ - b.makeReceiveIPv4(b.ipv4), + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, }, - portAddr.Port(), nil + + ipv4MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv4.Message, wgConn.DefaultBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + + ipv6MsgsPool: sync.Pool{ + New: func() any { + msgs := make([]ipv6.Message, wgConn.DefaultBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, srcControlSize) + } + return &msgs + }, + }, + transportNet: transportNet, + } } -func listenNet(network string, port int) (net.PacketConn, int, error) { - c, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) +var ( + _ wgConn.Bind = (*ICEBind)(nil) + _ wgConn.Endpoint = &wgConn.StdNetEndpoint{} +) + +func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { + e, err := netip.ParseAddrPort(s) + return asEndpoint(e), err +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } - lAddr := c.LocalAddr() - uAddr, err := net.ResolveUDPAddr( - lAddr.Network(), - lAddr.String(), + // Retrieve port. + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), ) if err != nil { return nil, 0, err } - return c, uAddr.Port, nil + return conn.(*net.UDPConn), uaddr.Port, nil } -func parseSTUNMessage(raw []byte) (*stun.Message, error) { - msg := &stun.Message{ - Raw: raw, - } - if err := msg.Decode(); err != nil { - return nil, err +func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var err error + var tries int + + if s.ipv4 != nil || s.ipv6 != nil { + return nil, 0, wgConn.ErrBindAlreadyOpen } - return msg, nil + // Attempt to open ipv4 and ipv6 listeners on the same port. + // If uport is 0, we can retry on failure. +again: + port := int(uport) + var v4conn, v6conn *net.UDPConn + + v4conn, port, err = listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err + } + + // Listen on the same port as we're using for ipv4. + v6conn, port, err = listenNet("udp6", port) + if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + v4conn.Close() + tries++ + goto again + } + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + v4conn.Close() + return nil, 0, err + } + var fns []wgConn.ReceiveFunc + if v4conn != nil { + fns = append(fns, s.receiveIPv4) + s.ipv4 = v4conn + } + if v6conn != nil { + fns = append(fns, s.receiveIPv6) + s.ipv6 = v6conn + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + + s.ipv4PC = ipv4.NewPacketConn(s.ipv4) + s.ipv6PC = ipv6.NewPacketConn(s.ipv6) + + s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet}) + return fns, uint16(port), nil } -func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { - return func(buff []byte) (int, conn.Endpoint, error) { - n, endpoint, err := c.ReadFrom(buff) - if err != nil { - return 0, nil, err - } - e, err := netip.ParseAddrPort(endpoint.String()) - if err != nil { - return 0, nil, err - } - if !stun.IsMessage(buff) { - // WireGuard traffic - return n, (conn.StdNetEndpoint)(netip.AddrPortFrom(e.Addr(), e.Port())), nil - } - - msg, err := parseSTUNMessage(buff[:n]) - if err != nil { - return 0, nil, err - } - - err = b.udpMux.HandleSTUNMessage(msg, endpoint) - if err != nil { - log.Warnf("failed to handle packet") - } - - // discard packets because they are STUN related - return 0, nil, nil //todo proper return +func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + defer s.ipv4MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] } + numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil } -// Close closes the WireGuard socket and UDPMux -func (b *ICEBind) Close() error { - b.mu.Lock() - defer b.mu.Unlock() +func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + defer s.ipv6MsgsPool.Put(msgs) + for i := range buffs { + (*msgs)[i].Buffers[0] = buffs[i] + } + numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := asEndpoint(addrPort) + getSrcFromControl(msg.OOB, ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *ICEBind) BatchSize() int { + return s.batchSize +} + +func (s *ICEBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() var err1, err2 error - if b.ipv4 != nil { - c := b.ipv4 - b.ipv4 = nil - err1 = c.Close() + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil } - - if b.udpMux != nil { - m := b.udpMux - b.udpMux = nil - err2 = m.Close() + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil } - + s.blackhole4 = false + s.blackhole6 = false if err1 != nil { return err1 } - return err2 } -// SetMark sets the mark for each packet sent through this Bind. -// This mark is passed to the kernel as the socket option SO_MARK. -func (b *ICEBind) SetMark(mark uint32) error { - return nil +func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + is6 = true + } + s.mu.Unlock() + + if blackhole { + return nil + } + if conn == nil { + return syscall.EAFNOSUPPORT + } + if is6 { + return s.send6(s.ipv6PC, endpoint, buffs) + } else { + return s.send4(s.ipv4PC, endpoint, buffs) + } } -// Send bytes to the remote endpoint (peer) -func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error { - - nend, ok := endpoint.(conn.StdNetEndpoint) - if !ok { - return conn.ErrWrongEndpointType +// GetICEMux returns the ICE UDPMux that was created and used by ICEBind +func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.udpMux == nil { + return nil, fmt.Errorf("ICEBind has not been initialized yet") } - addrPort := netip.AddrPort(nend) - _, err := b.ipv4.WriteTo(buff, &net.UDPAddr{ - IP: addrPort.Addr().AsSlice(), - Port: int(addrPort.Port()), - Zone: addrPort.Addr().Zone(), - }) + + return s.udpMux, nil +} + +func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as4 := ep.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + ua.Port = int(ep.(*wgConn.StdNetEndpoint).Port()) + msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*wgConn.StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv4MsgsPool.Put(msgs) return err } -// ParseEndpoint creates a new endpoint from a string. -func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { - e, err := netip.ParseAddrPort(s) - return asEndpoint(e), err +func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error { + ua := s.udpAddrPool.Get().(*net.UDPAddr) + as16 := ep.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + ua.Port = int(ep.(*wgConn.StdNetEndpoint).Port()) + msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) + for i, buff := range buffs { + (*msgs)[i].Buffers[0] = buff + (*msgs)[i].Addr = ua + setSrcControl(&(*msgs)[i].OOB, ep.(*wgConn.StdNetEndpoint)) + } + var ( + n int + err error + start int + ) + for { + n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0) + if err != nil || n == len((*msgs)[start:len(buffs)]) { + break + } + start += n + } + s.udpAddrPool.Put(ua) + s.ipv6MsgsPool.Put(msgs) + return err } // endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint. @@ -191,17 +323,17 @@ func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) { // but Endpoints are immutable, so we can re-use them. var endpointPool = sync.Pool{ New: func() any { - return make(map[netip.AddrPort]conn.Endpoint) + return make(map[netip.AddrPort]*wgConn.StdNetEndpoint) }, } // asEndpoint returns an Endpoint containing ap. -func asEndpoint(ap netip.AddrPort) conn.Endpoint { - m := endpointPool.Get().(map[netip.AddrPort]conn.Endpoint) +func asEndpoint(ap netip.AddrPort) *wgConn.StdNetEndpoint { + m := endpointPool.Get().(map[netip.AddrPort]*wgConn.StdNetEndpoint) defer endpointPool.Put(m) e, ok := m[ap] if !ok { - e = conn.Endpoint(conn.StdNetEndpoint(ap)) + e = &wgConn.StdNetEndpoint{AddrPort: ap} m[ap] = e } return e diff --git a/iface/bind/controlfns.go b/iface/bind/controlfns.go new file mode 100644 index 000000000..17762484b --- /dev/null +++ b/iface/bind/controlfns.go @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +import ( + "net" + "syscall" +) + +// controlFn is the callback function signature from net.ListenConfig.Control. +// It is used to apply platform specific configuration to the socket prior to +// bind. +type controlFn func(network, address string, c syscall.RawConn) error + +// controlFns is a list of functions that are called from the listen config +// that can apply socket options. +var controlFns = []controlFn{} + +// listenConfig returns a net.ListenConfig that applies the controlFns to the +// socket prior to bind. This is used to apply socket buffer sizing and packet +// information OOB configuration for sticky sockets. +func listenConfig() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + for _, fn := range controlFns { + if err := fn(network, address, c); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/iface/bind/controlfns_linux.go b/iface/bind/controlfns_linux.go new file mode 100644 index 000000000..c78f0b36c --- /dev/null +++ b/iface/bind/controlfns_linux.go @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + + // Enable receiving of the packet information (IP_PKTINFO for IPv4, + // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. + func(network, address string, c syscall.RawConn) error { + var err error + switch network { + case "udp4": + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + }) + case "udp6": + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + if err != nil { + return + } + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + default: + err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) + } + return err + }, + ) +} diff --git a/iface/bind/controlfns_unix.go b/iface/bind/controlfns_unix.go new file mode 100644 index 000000000..4af8aa4b8 --- /dev/null +++ b/iface/bind/controlfns_unix.go @@ -0,0 +1,28 @@ +//go:build !windows && !linux && !js + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + var err error + if network == "udp6" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + } + return err + }, + ) +} diff --git a/iface/bind/mark_default.go b/iface/bind/mark_default.go new file mode 100644 index 000000000..f279a2cb7 --- /dev/null +++ b/iface/bind/mark_default.go @@ -0,0 +1,12 @@ +//go:build !linux && !openbsd && !freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +func (s *ICEBind) SetMark(mark uint32) error { + return nil +} diff --git a/iface/bind/mark_unix.go b/iface/bind/mark_unix.go new file mode 100644 index 000000000..d9e46eea7 --- /dev/null +++ b/iface/bind/mark_unix.go @@ -0,0 +1,65 @@ +//go:build linux || openbsd || freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "runtime" + + "golang.org/x/sys/unix" +) + +var fwmarkIoctl int + +func init() { + switch runtime.GOOS { + case "linux", "android": + fwmarkIoctl = 36 /* unix.SO_MARK */ + case "freebsd": + fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ + case "openbsd": + fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ + } +} + +func (s *StdNetBind) SetMark(mark uint32) error { + var operr error + if fwmarkIoctl == 0 { + return nil + } + if s.ipv4 != nil { + fd, err := s.ipv4.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + if s.ipv6 != nil { + fd, err := s.ipv6.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + return nil +} diff --git a/iface/bind/sticky_default.go b/iface/bind/sticky_default.go new file mode 100644 index 000000000..3887caf82 --- /dev/null +++ b/iface/bind/sticky_default.go @@ -0,0 +1,28 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +import wgConn "golang.zx2c4.com/wireguard/conn" + +// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but +// use alternatively named flags and need ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *wgConn.StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *wgConn.StdNetEndpoint) { +} + +// srcControlSize returns the recommended buffer size for pooling sticky control +// data. +const srcControlSize = 0 diff --git a/iface/bind/sticky_linux.go b/iface/bind/sticky_linux.go new file mode 100644 index 000000000..380babfa3 --- /dev/null +++ b/iface/bind/sticky_linux.go @@ -0,0 +1,111 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +import ( + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + info := pktInfoFromBuf[unix.Inet4Pktinfo](data) + ep.src.Addr = netip.AddrFrom4(info.Spec_dst) + ep.src.ifidx = info.Ifindex + + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + info := pktInfoFromBuf[unix.Inet6Pktinfo](data) + ep.src.Addr = netip.AddrFrom16(info.Addr) + ep.src.ifidx = int32(info.Ifindex) + + return + } + } +} + +// pktInfoFromBuf returns type T populated from the provided buf via copy(). It +// panics if buf is of insufficient size. +func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { + size := int(unsafe.Sizeof(t)) + if len(buf) < size { + panic("pktInfoFromBuf: buffer too small") + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) + return t +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + *control = (*control)[:cap(*control)] + if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { + *control = (*control)[:0] + return + } + + if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { + *control = (*control)[:0] + return + } + + if len(*control) < srcControlSize { + *control = (*control)[:0] + return + } + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) + if ep.SrcIP().Is4() { + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) + info.Ifindex = ep.src.ifidx + if ep.SrcIP().IsValid() { + info.Spec_dst = ep.SrcIP().As4() + } + } else { + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.Len = unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo + + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) + info.Ifindex = uint32(ep.src.ifidx) + if ep.SrcIP().IsValid() { + info.Addr = ep.SrcIP().As16() + } + } + + *control = (*control)[:hdr.Len] +} + +var srcControlSize = unix.CmsgLen(unix.SizeofInet6Pktinfo) diff --git a/iface/bind/sticky_linux_test.go b/iface/bind/sticky_linux_test.go new file mode 100644 index 000000000..8fe9b9331 --- /dev/null +++ b/iface/bind/sticky_linux_test.go @@ -0,0 +1,207 @@ +//go:build linux +// +build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bind + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + ep.src.Addr = netip.MustParseAddr("127.0.0.1") + ep.src.ifidx = 5 + + control := make([]byte, srcControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + ep.src.Addr = netip.MustParseAddr("::1") + ep.src.ifidx = 5 + + control := make([]byte, srcControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if hdr.Len != uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.Len = uint64(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + control := make([]byte, srcControlSize) + ep := &StdNetEndpoint{} + ep.src.Addr = netip.MustParseAddr("::1") + ep.src.ifidx = 5 + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 0 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +}