mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 17:08:53 +00:00
Use offload
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package bind
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -12,6 +13,8 @@ import (
|
||||
"github.com/pion/transport/v3"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
@@ -24,8 +27,12 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
|
||||
const (
|
||||
udpSegmentMaxDatagrams = 64
|
||||
)
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload)
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
@@ -51,6 +58,7 @@ type ICEBind struct {
|
||||
|
||||
muUDPMux sync.Mutex
|
||||
udpMux *UniversalUDPMuxDefault
|
||||
msgsPool sync.Pool
|
||||
}
|
||||
|
||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
@@ -63,11 +71,24 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
||||
endpoints: make(map[netip.Addr]net.Conn),
|
||||
closedChan: make(chan struct{}),
|
||||
closed: true,
|
||||
msgsPool: sync.Pool{
|
||||
New: func() any {
|
||||
// ipv6.Message and ipv4.Message are interchangeable as they are
|
||||
// both aliases for x/net/internal/socket.Message.
|
||||
msgs := make([]ipv6.Message, wgConn.IdealBatchSize)
|
||||
for i := range msgs {
|
||||
msgs[i].Buffers = make(net.Buffers, 1)
|
||||
msgs[i].OOB = make([]byte, 0, wgConn.StickyControlSize+unix.CmsgSpace(2))
|
||||
}
|
||||
return &msgs
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
rc := receiverCreator{
|
||||
ib,
|
||||
}
|
||||
|
||||
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
|
||||
return ib
|
||||
}
|
||||
@@ -154,7 +175,7 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
|
||||
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) wgConn.ReceiveFunc {
|
||||
s.muUDPMux.Lock()
|
||||
defer s.muUDPMux.Unlock()
|
||||
|
||||
@@ -165,44 +186,83 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
|
||||
FilterFn: s.filterFn,
|
||||
},
|
||||
)
|
||||
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
|
||||
defer ipv4MsgsPool.Put(msgs)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ICEBind) receiveIP(
|
||||
br batchReader,
|
||||
conn *net.UDPConn,
|
||||
rxOffload bool,
|
||||
bufs [][]byte,
|
||||
sizes []int,
|
||||
eps []wgConn.Endpoint) (n int, err error) {
|
||||
|
||||
msgs := s.msgsPool.Get().(*[]ipv6.Message)
|
||||
for i := range bufs {
|
||||
(*msgs)[i].Buffers[0] = bufs[i]
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
||||
}
|
||||
defer func() {
|
||||
for i := range *msgs {
|
||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
||||
}
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" {
|
||||
numMsgs, err = pc.ReadBatch(*msgs, 0)
|
||||
s.msgsPool.Put(msgs)
|
||||
}()
|
||||
var numMsgs int
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
if rxOffload {
|
||||
readAt := len(*msgs) - (wgConn.IdealBatchSize / udpSegmentMaxDatagrams)
|
||||
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs, err = splitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
numMsgs, err = br.ReadBatch(*msgs, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
}
|
||||
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
} else {
|
||||
msg := &(*msgs)[0]
|
||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return numMsgs, nil
|
||||
numMsgs = 1
|
||||
}
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
msg := &(*msgs)[i]
|
||||
|
||||
// todo: handle err
|
||||
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
|
||||
if ok {
|
||||
sizes[i] = 0
|
||||
} else {
|
||||
sizes[i] = msg.N
|
||||
}
|
||||
|
||||
sizes[i] = msg.N
|
||||
if sizes[i] == 0 {
|
||||
continue
|
||||
}
|
||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||
eps[i] = ep
|
||||
}
|
||||
return numMsgs, nil
|
||||
}
|
||||
|
||||
type batchReader interface {
|
||||
ReadBatch([]ipv6.Message, int) (int, error)
|
||||
}
|
||||
|
||||
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
|
||||
@@ -273,3 +333,49 @@ func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
||||
}
|
||||
return newAddr, nil
|
||||
}
|
||||
|
||||
type getGSOFunc func(control []byte) (int, error)
|
||||
|
||||
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
||||
for i := firstMsgAt; i < len(msgs); i++ {
|
||||
msg := &msgs[i]
|
||||
if msg.N == 0 {
|
||||
return n, err
|
||||
}
|
||||
var (
|
||||
gsoSize int
|
||||
start int
|
||||
end = msg.N
|
||||
numToSplit = 1
|
||||
)
|
||||
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
if gsoSize > 0 {
|
||||
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
||||
end = gsoSize
|
||||
}
|
||||
for j := 0; j < numToSplit; j++ {
|
||||
if n > i {
|
||||
return n, errors.New("splitting coalesced packet resulted in overflow")
|
||||
}
|
||||
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
||||
msgs[n].N = copied
|
||||
msgs[n].Addr = msg.Addr
|
||||
start = end
|
||||
end += gsoSize
|
||||
if end > msg.N {
|
||||
end = msg.N
|
||||
}
|
||||
n++
|
||||
}
|
||||
if i != n-1 {
|
||||
// It is legal for bytes to move within msg.Buffers[0] as a result
|
||||
// of splitting, so we only zero the source msg len when it is not
|
||||
// the destination of the last split operation above.
|
||||
msg.N = 0
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user