mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
341 lines
7.5 KiB
Go
341 lines
7.5 KiB
Go
package bind
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"strconv"
|
|
"sync"
|
|
"syscall"
|
|
|
|
"github.com/pion/transport/v2"
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
|
)
|
|
|
|
var (
|
|
_ wgConn.Bind = (*ICEBind)(nil)
|
|
)
|
|
|
|
// ICEBind implements Bind for all platforms except Windows.
|
|
type ICEBind struct {
|
|
mu sync.Mutex // protects following fields
|
|
ipv4 *net.UDPConn
|
|
ipv6 *net.UDPConn
|
|
blackhole4 bool
|
|
blackhole6 bool
|
|
ipv4PC *ipv4.PacketConn
|
|
ipv6PC *ipv6.PacketConn
|
|
batchSize int
|
|
udpAddrPool sync.Pool
|
|
ipv4MsgsPool sync.Pool
|
|
ipv6MsgsPool sync.Pool
|
|
|
|
// NetBird related variables
|
|
transportNet transport.Net
|
|
udpMux *UniversalUDPMuxDefault
|
|
}
|
|
|
|
func NewICEBind(transportNet transport.Net) *ICEBind {
|
|
return &ICEBind{
|
|
batchSize: wgConn.DefaultBatchSize,
|
|
|
|
udpAddrPool: sync.Pool{
|
|
New: func() any {
|
|
return &net.UDPAddr{
|
|
IP: make([]byte, 16),
|
|
}
|
|
},
|
|
},
|
|
|
|
ipv4MsgsPool: sync.Pool{
|
|
New: func() any {
|
|
msgs := make([]ipv4.Message, wgConn.DefaultBatchSize)
|
|
for i := range msgs {
|
|
msgs[i].Buffers = make(net.Buffers, 1)
|
|
msgs[i].OOB = make([]byte, srcControlSize)
|
|
}
|
|
return &msgs
|
|
},
|
|
},
|
|
|
|
ipv6MsgsPool: sync.Pool{
|
|
New: func() any {
|
|
msgs := make([]ipv6.Message, wgConn.DefaultBatchSize)
|
|
for i := range msgs {
|
|
msgs[i].Buffers = make(net.Buffers, 1)
|
|
msgs[i].OOB = make([]byte, srcControlSize)
|
|
}
|
|
return &msgs
|
|
},
|
|
},
|
|
transportNet: transportNet,
|
|
}
|
|
}
|
|
|
|
var (
|
|
_ wgConn.Bind = (*ICEBind)(nil)
|
|
_ wgConn.Endpoint = &wgConn.StdNetEndpoint{}
|
|
)
|
|
|
|
func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
|
|
e, err := netip.ParseAddrPort(s)
|
|
return asEndpoint(e), err
|
|
}
|
|
|
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
|
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Retrieve port.
|
|
laddr := conn.LocalAddr()
|
|
uaddr, err := net.ResolveUDPAddr(
|
|
laddr.Network(),
|
|
laddr.String(),
|
|
)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return conn.(*net.UDPConn), uaddr.Port, nil
|
|
}
|
|
|
|
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var err error
|
|
var tries int
|
|
|
|
if s.ipv4 != nil || s.ipv6 != nil {
|
|
return nil, 0, wgConn.ErrBindAlreadyOpen
|
|
}
|
|
|
|
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
|
// If uport is 0, we can retry on failure.
|
|
again:
|
|
port := int(uport)
|
|
var v4conn, v6conn *net.UDPConn
|
|
|
|
v4conn, port, err = listenNet("udp4", port)
|
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
return nil, 0, err
|
|
}
|
|
|
|
// Listen on the same port as we're using for ipv4.
|
|
v6conn, port, err = listenNet("udp6", port)
|
|
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
|
v4conn.Close()
|
|
tries++
|
|
goto again
|
|
}
|
|
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
v4conn.Close()
|
|
return nil, 0, err
|
|
}
|
|
var fns []wgConn.ReceiveFunc
|
|
if v4conn != nil {
|
|
fns = append(fns, s.receiveIPv4)
|
|
s.ipv4 = v4conn
|
|
}
|
|
if v6conn != nil {
|
|
fns = append(fns, s.receiveIPv6)
|
|
s.ipv6 = v6conn
|
|
}
|
|
if len(fns) == 0 {
|
|
return nil, 0, syscall.EAFNOSUPPORT
|
|
}
|
|
|
|
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
|
|
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
|
|
|
|
s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet})
|
|
return fns, uint16(port), nil
|
|
}
|
|
|
|
func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
|
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
defer s.ipv4MsgsPool.Put(msgs)
|
|
for i := range buffs {
|
|
(*msgs)[i].Buffers[0] = buffs[i]
|
|
}
|
|
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i := 0; i < numMsgs; i++ {
|
|
msg := &(*msgs)[i]
|
|
sizes[i] = msg.N
|
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
ep := asEndpoint(addrPort)
|
|
getSrcFromControl(msg.OOB, ep)
|
|
eps[i] = ep
|
|
}
|
|
return numMsgs, nil
|
|
}
|
|
|
|
func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
|
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
defer s.ipv6MsgsPool.Put(msgs)
|
|
for i := range buffs {
|
|
(*msgs)[i].Buffers[0] = buffs[i]
|
|
}
|
|
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
for i := 0; i < numMsgs; i++ {
|
|
msg := &(*msgs)[i]
|
|
sizes[i] = msg.N
|
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
ep := asEndpoint(addrPort)
|
|
getSrcFromControl(msg.OOB, ep)
|
|
eps[i] = ep
|
|
}
|
|
return numMsgs, nil
|
|
}
|
|
|
|
func (s *ICEBind) BatchSize() int {
|
|
return s.batchSize
|
|
}
|
|
|
|
func (s *ICEBind) Close() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
var err1, err2 error
|
|
if s.ipv4 != nil {
|
|
err1 = s.ipv4.Close()
|
|
s.ipv4 = nil
|
|
}
|
|
if s.ipv6 != nil {
|
|
err2 = s.ipv6.Close()
|
|
s.ipv6 = nil
|
|
}
|
|
s.blackhole4 = false
|
|
s.blackhole6 = false
|
|
if err1 != nil {
|
|
return err1
|
|
}
|
|
return err2
|
|
}
|
|
|
|
func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error {
|
|
s.mu.Lock()
|
|
blackhole := s.blackhole4
|
|
conn := s.ipv4
|
|
is6 := false
|
|
if endpoint.DstIP().Is6() {
|
|
blackhole = s.blackhole6
|
|
conn = s.ipv6
|
|
is6 = true
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
if blackhole {
|
|
return nil
|
|
}
|
|
if conn == nil {
|
|
return syscall.EAFNOSUPPORT
|
|
}
|
|
if is6 {
|
|
return s.send6(s.ipv6PC, endpoint, buffs)
|
|
} else {
|
|
return s.send4(s.ipv4PC, endpoint, buffs)
|
|
}
|
|
}
|
|
|
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.udpMux == nil {
|
|
return nil, fmt.Errorf("ICEBind has not been initialized yet")
|
|
}
|
|
|
|
return s.udpMux, nil
|
|
}
|
|
|
|
func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
|
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
as4 := ep.DstIP().As4()
|
|
copy(ua.IP, as4[:])
|
|
ua.IP = ua.IP[:4]
|
|
ua.Port = int(ep.(*wgConn.StdNetEndpoint).Port())
|
|
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
|
|
for i, buff := range buffs {
|
|
(*msgs)[i].Buffers[0] = buff
|
|
(*msgs)[i].Addr = ua
|
|
setSrcControl(&(*msgs)[i].OOB, ep.(*wgConn.StdNetEndpoint))
|
|
}
|
|
var (
|
|
n int
|
|
err error
|
|
start int
|
|
)
|
|
for {
|
|
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
|
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
|
break
|
|
}
|
|
start += n
|
|
}
|
|
s.udpAddrPool.Put(ua)
|
|
s.ipv4MsgsPool.Put(msgs)
|
|
return err
|
|
}
|
|
|
|
func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
|
|
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
as16 := ep.DstIP().As16()
|
|
copy(ua.IP, as16[:])
|
|
ua.IP = ua.IP[:16]
|
|
ua.Port = int(ep.(*wgConn.StdNetEndpoint).Port())
|
|
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
|
|
for i, buff := range buffs {
|
|
(*msgs)[i].Buffers[0] = buff
|
|
(*msgs)[i].Addr = ua
|
|
setSrcControl(&(*msgs)[i].OOB, ep.(*wgConn.StdNetEndpoint))
|
|
}
|
|
var (
|
|
n int
|
|
err error
|
|
start int
|
|
)
|
|
for {
|
|
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
|
|
if err != nil || n == len((*msgs)[start:len(buffs)]) {
|
|
break
|
|
}
|
|
start += n
|
|
}
|
|
s.udpAddrPool.Put(ua)
|
|
s.ipv6MsgsPool.Put(msgs)
|
|
return err
|
|
}
|
|
|
|
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
|
|
// This exists to reduce allocations: Putting a netip.AddrPort in an Endpoint allocates,
|
|
// but Endpoints are immutable, so we can re-use them.
|
|
var endpointPool = sync.Pool{
|
|
New: func() any {
|
|
return make(map[netip.AddrPort]*wgConn.StdNetEndpoint)
|
|
},
|
|
}
|
|
|
|
// asEndpoint returns an Endpoint containing ap.
|
|
func asEndpoint(ap netip.AddrPort) *wgConn.StdNetEndpoint {
|
|
m := endpointPool.Get().(map[netip.AddrPort]*wgConn.StdNetEndpoint)
|
|
defer endpointPool.Put(m)
|
|
e, ok := m[ap]
|
|
if !ok {
|
|
e = &wgConn.StdNetEndpoint{AddrPort: ap}
|
|
m[ap] = e
|
|
}
|
|
return e
|
|
}
|