Files
netbird/iface/bind/bind.go
2023-04-17 14:37:48 +02:00

341 lines
7.5 KiB
Go

package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"syscall"
"github.com/pion/transport/v2"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// Compile-time assertion that *ICEBind satisfies the wgConn.Bind interface.
var (
_ wgConn.Bind = (*ICEBind)(nil)
)
// ICEBind implements Bind for all platforms except Windows.
//
// It owns one UDP socket per address family and, in addition to the
// wgConn.Bind methods, exposes a UDP mux (see GetICEMux) that is created on
// top of the IPv4 socket when the bind is opened.
type ICEBind struct {
mu sync.Mutex // protects following fields
// ipv4 and ipv6 are the listening sockets; they are set by Open and reset
// to nil by Close. A nil socket makes Send return EAFNOSUPPORT for that
// address family.
ipv4 *net.UDPConn
ipv6 *net.UDPConn
// blackhole4 and blackhole6 make Send silently drop packets for the
// respective address family. In this file they are only ever reset to
// false (in Close).
blackhole4 bool
blackhole6 bool
// ipv4PC and ipv6PC wrap the sockets for batched reads and writes
// (ReadBatch / WriteBatch); they are created in Open.
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
// batchSize is the value reported by BatchSize.
batchSize int
// udpAddrPool holds scratch *net.UDPAddr values (16-byte IP capacity)
// shared by send4 and send6.
udpAddrPool sync.Pool
// ipv4MsgsPool and ipv6MsgsPool hold *[]ipv4.Message / *[]ipv6.Message
// slices of wgConn.DefaultBatchSize entries, used by both the receive and
// the send paths.
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
// NetBird related variables
// transportNet is passed to the UDP mux created in Open.
transportNet transport.Net
// udpMux is created in Open and returned by GetICEMux; it is nil until
// Open has succeeded.
udpMux *UniversalUDPMuxDefault
}
// NewICEBind creates an ICEBind that is not yet listening; call Open to bind
// the sockets.
//
// transportNet is stored and later handed to the UDP mux created in Open. The
// batch size and the message pools are sized with wgConn.DefaultBatchSize.
func NewICEBind(transportNet transport.Net) *ICEBind {
	// Scratch destination address with room for an IPv6 address.
	newUDPAddr := func() any {
		return &net.UDPAddr{IP: make([]byte, 16)}
	}

	// One batch of IPv4 messages, each with a single payload slot and a
	// control buffer of srcControlSize bytes.
	newIPv4Batch := func() any {
		batch := make([]ipv4.Message, wgConn.DefaultBatchSize)
		for idx := range batch {
			batch[idx].Buffers = make(net.Buffers, 1)
			batch[idx].OOB = make([]byte, srcControlSize)
		}
		return &batch
	}

	// Same layout as above for IPv6.
	newIPv6Batch := func() any {
		batch := make([]ipv6.Message, wgConn.DefaultBatchSize)
		for idx := range batch {
			batch[idx].Buffers = make(net.Buffers, 1)
			batch[idx].OOB = make([]byte, srcControlSize)
		}
		return &batch
	}

	bind := &ICEBind{
		batchSize:    wgConn.DefaultBatchSize,
		transportNet: transportNet,
	}
	bind.udpAddrPool.New = newUDPAddr
	bind.ipv4MsgsPool.New = newIPv4Batch
	bind.ipv6MsgsPool.New = newIPv6Batch
	return bind
}
// Compile-time assertion that *wgConn.StdNetEndpoint, the endpoint type
// handed out by this package, satisfies the wgConn.Endpoint interface.
var _ wgConn.Endpoint = (*wgConn.StdNetEndpoint)(nil)
// ParseEndpoint parses s as an "ip:port" address (netip.ParseAddrPort syntax)
// and returns the corresponding endpoint.
//
// On a parse error it returns a nil Endpoint together with the error, so a
// caller can never end up holding an endpoint for the zero address.
func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) {
	ap, err := netip.ParseAddrPort(s)
	if err != nil {
		return nil, err
	}
	return asEndpoint(ap), nil
}
// listenNet opens a UDP listener for network ("udp4" or "udp6") on the
// wildcard address and the given port, and returns the connection together
// with the port it is actually bound to (relevant when port is 0).
//
// On any failure the returned connection is nil, the port is 0, and a socket
// that was already opened is closed again.
func listenNet(network string, port int) (*net.UDPConn, int, error) {
	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
	if err != nil {
		return nil, 0, err
	}

	udpConn, ok := conn.(*net.UDPConn)
	if !ok {
		// Best-effort cleanup; the type mismatch is the error worth reporting.
		_ = conn.Close()
		return nil, 0, fmt.Errorf("listening on %q: unexpected connection type %T", network, conn)
	}

	// Retrieve port.
	laddr := udpConn.LocalAddr()
	uaddr, err := net.ResolveUDPAddr(laddr.Network(), laddr.String())
	if err != nil {
		// Do not leak the socket when the local address cannot be resolved.
		_ = udpConn.Close()
		return nil, 0, err
	}
	return udpConn, uaddr.Port, nil
}
// Open binds the IPv4 and IPv6 UDP listeners on the same port and returns one
// receive function per address family that could be opened, together with
// the port that was bound.
//
// If uport is 0 the port is chosen by the first successful listener; when the
// IPv6 listener then reports EADDRINUSE for that port, the whole attempt is
// retried (up to 100 times). An address family that reports EAFNOSUPPORT is
// skipped; if neither family is available EAFNOSUPPORT is returned. Calling
// Open on an already open bind returns wgConn.ErrBindAlreadyOpen.
//
// On success the batch PacketConn wrappers and the UDP mux are (re)created.
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.ipv4 != nil || s.ipv6 != nil {
		return nil, 0, wgConn.ErrBindAlreadyOpen
	}

	var (
		v4conn, v6conn *net.UDPConn
		port           int
	)

	// Attempt to open ipv4 and ipv6 listeners on the same port.
	// If uport is 0, we can retry on failure.
	for tries := 0; ; tries++ {
		port = int(uport)

		c4, p4, err := listenNet("udp4", port)
		if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
			return nil, 0, err
		}
		if err == nil {
			// Only adopt the reported port on success: listenNet reports 0
			// on failure, which must not replace the requested port.
			port = p4
		}

		// Listen on the same port as we're using for ipv4.
		c6, p6, err := listenNet("udp6", port)
		if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
			// c4 is nil when IPv4 is unsupported on this host.
			if c4 != nil {
				_ = c4.Close()
			}
			continue
		}
		if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
			if c4 != nil {
				_ = c4.Close()
			}
			return nil, 0, err
		}
		if err == nil {
			port = p6
		}

		v4conn, v6conn = c4, c6
		break
	}

	var fns []wgConn.ReceiveFunc
	if v4conn != nil {
		fns = append(fns, s.receiveIPv4)
	}
	if v6conn != nil {
		fns = append(fns, s.receiveIPv6)
	}
	if len(fns) == 0 {
		return nil, 0, syscall.EAFNOSUPPORT
	}

	// Wrap only the sockets that exist; a PacketConn built around a nil
	// *net.UDPConn would not be usable.
	s.ipv4, s.ipv6 = v4conn, v6conn
	if v4conn != nil {
		s.ipv4PC = ipv4.NewPacketConn(v4conn)
	}
	if v6conn != nil {
		s.ipv6PC = ipv6.NewPacketConn(v6conn)
	}

	// NOTE(review): the mux is built on the IPv4 socket only; s.ipv4 is nil
	// here when just the IPv6 listener could be opened — confirm how
	// NewUniversalUDPMuxDefault handles that.
	s.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: s.ipv4, Net: s.transportNet})
	return fns, uint16(port), nil
}
// receiveIPv4 reads a batch of packets from the IPv4 socket.
//
// Packet i is written into buffs[i]; its length is stored in sizes[i] and its
// source in eps[i]. The return value is the number of packets read. sizes and
// eps must be at least as long as buffs, and buffs must not be longer than
// the pooled message batch (wgConn.DefaultBatchSize).
func (s *ICEBind) receiveIPv4(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
	msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
	defer s.ipv4MsgsPool.Put(msgs)

	for i := range buffs {
		(*msgs)[i].Buffers[0] = buffs[i]
	}

	// Offer only the messages that were just given a caller buffer; the
	// remaining pooled messages still point at buffers from an earlier call.
	numMsgs, err := s.ipv4PC.ReadBatch((*msgs)[:len(buffs)], 0)
	if err != nil {
		return 0, err
	}

	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := asEndpoint(addrPort)
		// Only the first NN bytes of OOB were filled in by this read; the
		// rest of the pooled buffer may hold control data of an older packet.
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}
// receiveIPv6 reads a batch of packets from the IPv6 socket.
//
// Packet i is written into buffs[i]; its length is stored in sizes[i] and its
// source in eps[i]. The return value is the number of packets read. sizes and
// eps must be at least as long as buffs, and buffs must not be longer than
// the pooled message batch (wgConn.DefaultBatchSize).
func (s *ICEBind) receiveIPv6(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
	msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
	defer s.ipv6MsgsPool.Put(msgs)

	for i := range buffs {
		(*msgs)[i].Buffers[0] = buffs[i]
	}

	// Offer only the messages that were just given a caller buffer; the
	// remaining pooled messages still point at buffers from an earlier call.
	numMsgs, err := s.ipv6PC.ReadBatch((*msgs)[:len(buffs)], 0)
	if err != nil {
		return 0, err
	}

	for i := 0; i < numMsgs; i++ {
		msg := &(*msgs)[i]
		sizes[i] = msg.N
		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
		ep := asEndpoint(addrPort)
		// Only the first NN bytes of OOB were filled in by this read; the
		// rest of the pooled buffer may hold control data of an older packet.
		getSrcFromControl(msg.OOB[:msg.NN], ep)
		eps[i] = ep
	}
	return numMsgs, nil
}
// BatchSize reports the batch size this bind was configured with in
// NewICEBind.
func (s *ICEBind) BatchSize() int {
	size := s.batchSize
	return size
}
// Close shuts down both listening sockets, if open, and clears the blackhole
// flags so that the bind can be opened again.
//
// Both sockets are always closed; when both fail, the IPv4 error is the one
// reported.
func (s *ICEBind) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var v4Err, v6Err error
	if c := s.ipv4; c != nil {
		v4Err = c.Close()
		s.ipv4 = nil
	}
	if c := s.ipv6; c != nil {
		v6Err = c.Close()
		s.ipv6 = nil
	}
	s.blackhole4, s.blackhole6 = false, false

	if v4Err == nil {
		return v6Err
	}
	return v4Err
}
// Send writes the packets in buffs to endpoint, choosing the IPv4 or IPv6
// socket from the endpoint's destination address.
//
// It returns nil without sending when the address family is blackholed, and
// EAFNOSUPPORT when no socket is open for that family.
func (s *ICEBind) Send(buffs [][]byte, endpoint wgConn.Endpoint) error {
	is6 := endpoint.DstIP().Is6()

	// Snapshot everything that mu protects, including the PacketConn
	// wrappers, so nothing is read after the lock has been released.
	s.mu.Lock()
	blackhole, conn := s.blackhole4, s.ipv4
	if is6 {
		blackhole, conn = s.blackhole6, s.ipv6
	}
	pc4, pc6 := s.ipv4PC, s.ipv6PC
	s.mu.Unlock()

	if blackhole {
		return nil
	}
	if conn == nil {
		return syscall.EAFNOSUPPORT
	}
	if is6 {
		return s.send6(pc6, endpoint, buffs)
	}
	return s.send4(pc4, endpoint, buffs)
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind.
// An error is returned while no mux exists yet, i.e. before Open has created
// one.
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
	s.mu.Lock()
	mux := s.udpMux
	s.mu.Unlock()

	if mux != nil {
		return mux, nil
	}
	return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
// send4 writes every buffer in buffs to the IPv4 destination of ep through
// conn, using as many WriteBatch calls as needed.
//
// ep must be a *wgConn.StdNetEndpoint, and buffs must not hold more entries
// than a pooled message batch. The first WriteBatch error, if any, is
// returned.
func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
	// Fill a pooled destination address with the endpoint's IPv4 address.
	addr := s.udpAddrPool.Get().(*net.UDPAddr)
	dst := ep.DstIP().As4()
	copy(addr.IP, dst[:])
	addr.IP = addr.IP[:4]
	stdEP := ep.(*wgConn.StdNetEndpoint)
	addr.Port = int(stdEP.Port())

	// One message per buffer, all sharing the same destination; the control
	// buffer of each message is prepared by setSrcControl.
	msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
	for i := range buffs {
		m := &(*msgs)[i]
		m.Buffers[0] = buffs[i]
		m.Addr = addr
		setSrcControl(&m.OOB, stdEP)
	}

	// WriteBatch may accept only part of the batch; keep going with the rest
	// until everything has been sent or an error occurs.
	pending := (*msgs)[:len(buffs)]
	var err error
	for {
		var sent int
		sent, err = conn.WriteBatch(pending, 0)
		if err != nil || sent == len(pending) {
			break
		}
		pending = pending[sent:]
	}

	s.udpAddrPool.Put(addr)
	s.ipv4MsgsPool.Put(msgs)
	return err
}
// send6 writes every buffer in buffs to the IPv6 destination of ep through
// conn, using as many WriteBatch calls as needed.
//
// ep must be a *wgConn.StdNetEndpoint, and buffs must not hold more entries
// than a pooled message batch. The first WriteBatch error, if any, is
// returned.
func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byte) error {
	ua := s.udpAddrPool.Get().(*net.UDPAddr)
	defer s.udpAddrPool.Put(ua)

	// The address pool is shared with send4, which leaves ua.IP truncated to
	// 4 bytes. Restore the full length BEFORE copying: copy only writes
	// len(dst) bytes, so copying first would leave 12 stale address bytes.
	ua.IP = ua.IP[:16]
	as16 := ep.DstIP().As16()
	copy(ua.IP, as16[:])
	stdEP := ep.(*wgConn.StdNetEndpoint)
	ua.Port = int(stdEP.Port())

	msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
	defer s.ipv6MsgsPool.Put(msgs)

	// One message per buffer, all sharing the same destination; the control
	// buffer of each message is prepared by setSrcControl.
	for i, buff := range buffs {
		msg := &(*msgs)[i]
		msg.Buffers[0] = buff
		msg.Addr = ua
		setSrcControl(&msg.OOB, stdEP)
	}

	// WriteBatch may accept only part of the batch; keep going with the rest
	// until everything has been sent or an error occurs.
	pending := (*msgs)[:len(buffs)]
	for {
		n, err := conn.WriteBatch(pending, 0)
		if err != nil {
			return err
		}
		if n == len(pending) {
			return nil
		}
		pending = pending[n:]
	}
}
// endpointPool contains a re-usable set of mappings from netip.AddrPort to Endpoint.
// This exists to reduce allocations: putting a netip.AddrPort in an Endpoint allocates,
// so an Endpoint created once for an address can be handed out again.
//
// Each pooled item is a map[netip.AddrPort]*wgConn.StdNetEndpoint; New returns
// an empty map.
//
// NOTE(review): the stated rationale is that Endpoints are immutable. The
// receive paths in this file pass endpoints to getSrcFromControl, which takes
// the endpoint as its only output argument — confirm that the immutability
// assumption actually holds before relying on shared instances.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]*wgConn.StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
//
// A new StdNetEndpoint is allocated on every call and is owned exclusively by
// the caller. It deliberately does not hand out cached, shared instances:
//   - the receive paths pass the returned endpoint to getSrcFromControl as an
//     output argument, so a shared instance could be modified by one
//     goroutine while another one is still using it;
//   - a cache keyed by remote address gains an entry for every distinct
//     source address seen on the wire and nothing in this file prunes it.
func asEndpoint(ap netip.AddrPort) *wgConn.StdNetEndpoint {
	return &wgConn.StdNetEndpoint{AddrPort: ap}
}