Implement ICEBind

This commit is contained in:
braginini
2022-09-06 20:06:51 +02:00
parent 9350c5f8d8
commit 2829cce644
9 changed files with 893 additions and 123 deletions

View File

@@ -1,99 +1,157 @@
package iface
import (
"errors"
"fmt"
"github.com/pion/stun"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/conn"
"net"
"net/netip"
"sync"
"syscall"
)
type UserEndpoint struct {
conn.StdNetEndpoint
type ICEBind struct {
sharedConn net.PacketConn
iceMux *UniversalUDPMuxDefault
mu sync.Mutex // protects following fields
}
type packet struct {
buff []byte
addr *net.UDPAddr
func (b *ICEBind) GetSharedConn() (net.PacketConn, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.sharedConn == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return b.sharedConn, nil
}
type UserBind struct {
endpointsLock sync.RWMutex
endpoints map[netip.AddrPort]*UserEndpoint
sharedConn net.PacketConn
func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.iceMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
Packets chan packet
closeSignal chan struct{}
return b.iceMux, nil
}
func NewUserBind(sharedConn net.PacketConn) *UserBind {
return &UserBind{sharedConn: sharedConn}
}
func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) {
b.mu.Lock()
defer b.mu.Unlock()
func (b *UserBind) Open(port uint16) ([]conn.ReceiveFunc, uint16, error) {
if b.sharedConn != nil {
return nil, 0, conn.ErrBindAlreadyOpen
}
b.Packets = make(chan packet, 1000)
b.closeSignal = make(chan struct{})
port := int(uport)
ipv4Conn, port, err := listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
b.sharedConn = ipv4Conn
b.iceMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn})
return []conn.ReceiveFunc{b.receive}, port, nil
}
func (b *UserBind) receive(buff []byte) (int, conn.Endpoint, error) {
/*n, endpoint, err := b.sharedConn.ReadFrom(buff)
portAddr, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String())
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
return n, (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.addr().AsSlice(),
Port: int(e.Port()),
Zone: e.addr().Zone(),
}), err*/
select {
case <-b.closeSignal:
return 0, nil, net.ErrClosed
case pkt := <-b.Packets:
/*log.Infof("received packet %d from %s to copy to buffer %d", binary.Size(pkt.buff), pkt.addr.String(),
len(buff))*/
return copy(buff, pkt.buff), (*conn.StdNetEndpoint)(pkt.addr), nil
return nil, 0, err
}
return []conn.ReceiveFunc{b.makeReceiveIPv4(b.sharedConn)}, portAddr.Port(), nil
}
func (b *UserBind) Close() error {
if b.closeSignal != nil {
select {
case <-b.closeSignal:
default:
close(b.closeSignal)
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// Retrieve port.
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc {
return func(buff []byte) (int, conn.Endpoint, error) {
n, endpoint, err := c.ReadFrom(buff)
if err != nil {
return 0, nil, err
}
e, err := netip.ParseAddrPort(endpoint.String())
if err != nil {
return 0, nil, err
}
if !stun.IsMessage(buff[:n]) {
// WireGuard traffic
return n, (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}), nil
}
err = b.iceMux.HandlePacket(buff, n, endpoint)
if err != nil {
return 0, nil, err
}
if err != nil {
log.Warnf("failed to handle packet")
}
// discard packets because they are STUN related
return 0, nil, nil //todo proper return
}
return nil
}
func (b *ICEBind) Close() error {
b.mu.Lock()
defer b.mu.Unlock()
var err1, err2 error
if b.sharedConn != nil {
c := b.sharedConn
b.sharedConn = nil
err1 = c.Close()
}
if b.iceMux != nil {
m := b.iceMux
b.iceMux = nil
err2 = m.Close()
}
if err1 != nil {
return err1
}
return err2
}
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
func (b *UserBind) SetMark(mark uint32) error {
func (b *ICEBind) SetMark(mark uint32) error {
return nil
}
func (b *UserBind) Send(buff []byte, endpoint conn.Endpoint) error {
func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error {
nend, ok := endpoint.(*conn.StdNetEndpoint)
if !ok {
return conn.ErrWrongEndpointType
}
//log.Infof("sending packet %d from %s to %s", binary.Size(buff), b.sharedConn.LocalAddr().String(), (*net.UDPAddr)(nend).String())
_, err := b.sharedConn.WriteTo(buff, (*net.UDPAddr)(nend))
return err
}
// ParseEndpoint creates a new endpoint from a string.
func (b *UserBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
func (b *ICEBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
e, err := netip.ParseAddrPort(s)
return (*conn.StdNetEndpoint)(&net.UDPAddr{
IP: e.Addr().AsSlice(),
@@ -101,10 +159,3 @@ func (b *UserBind) ParseEndpoint(s string) (ep conn.Endpoint, err error) {
Zone: e.Addr().Zone(),
}), err
}
func (b *UserBind) OnData(buff []byte, addr *net.UDPAddr) {
b.Packets <- packet{
buff: buff,
addr: addr,
}
}