diff --git a/client/internal/engine.go b/client/internal/engine.go index 46fff1da7..6a3940b04 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -89,7 +89,8 @@ type Engine struct { wgInterface *iface.WGIface - iceMux ice.UniversalUDPMux + iceMux ice.UniversalUDPMux + iceHostMux ice.UDPMux // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 @@ -249,6 +250,12 @@ func (e *Engine) Start() error { } e.iceMux = iceMux + iceHostMux, err := bind.GetICEHostMux() + if err != nil { + return err + } + e.iceHostMux = iceHostMux + log.Infof("NetBird Engine started listening on WireGuard port %d", *port) e.receiveSignalEvents() @@ -743,7 +750,7 @@ func (e Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, er StunTurn: stunTurn, InterfaceBlackList: e.config.IFaceBlackList, Timeout: timeout, - UDPMux: e.iceMux, + UDPMux: e.iceHostMux, UDPMuxSrflx: e.iceMux, ProxyConfig: proxyConfig, LocalWgPort: e.config.WgPort, diff --git a/iface/bind.go b/iface/bind.go index 7d3037cb1..87fdae707 100644 --- a/iface/bind.go +++ b/iface/bind.go @@ -12,31 +12,40 @@ import ( "syscall" ) -type ICEBind struct { - sharedConn net.PacketConn - iceMux *UniversalUDPMuxDefault - - mu sync.Mutex // protects following fields +type BindMux interface { + HandlePacket(p []byte, n int, addr net.Addr) error + Type() string } -func (b *ICEBind) GetSharedConn() (net.PacketConn, error) { - b.mu.Lock() - defer b.mu.Unlock() - if b.sharedConn == nil { - return nil, fmt.Errorf("ICEBind has not been initialized yet") - } +type ICEBind struct { + sharedConn net.PacketConn + sharedConnHost net.PacketConn + iceSrflxMux *UniversalUDPMuxDefault + iceHostMux *UDPMuxDefault - return b.sharedConn, nil + endpointMap map[string]net.PacketConn + + mu sync.Mutex // protects following fields } func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) { b.mu.Lock() defer b.mu.Unlock() - if b.iceMux == nil { + if b.iceSrflxMux == nil { return nil, fmt.Errorf("ICEBind has not been initialized yet") } - return b.iceMux, nil + return b.iceSrflxMux, nil +} + +func (b *ICEBind) GetICEHostMux() (UDPMux, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.iceHostMux == nil { + return nil, fmt.Errorf("ICEBind has not been initialized yet") + } + + return b.iceHostMux, nil } func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { @@ -46,20 +55,35 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { if b.sharedConn != nil { return nil, 0, conn.ErrBindAlreadyOpen } + if b.sharedConnHost != nil { + return nil, 0, conn.ErrBindAlreadyOpen + } + + b.endpointMap = make(map[string]net.PacketConn) port := int(uport) ipv4Conn, port, err := listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } + ipv4ConnHost, port, err := listenNet("udp4", 0) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err + } b.sharedConn = ipv4Conn - b.iceMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn}) + b.sharedConnHost = ipv4ConnHost + b.iceSrflxMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn}) + b.iceHostMux = NewUDPMuxDefault(UDPMuxParams{UDPConn: b.sharedConnHost}) portAddr, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String()) if err != nil { return nil, 0, err } - return []conn.ReceiveFunc{b.makeReceiveIPv4(b.sharedConn)}, portAddr.Port(), nil + return []conn.ReceiveFunc{ + b.makeReceiveIPv4(b.sharedConn, b.iceSrflxMux), + b.makeReceiveIPv4(b.sharedConnHost, b.iceHostMux), + }, + portAddr.Port(), nil } func listenNet(network string, port int) (*net.UDPConn, int, error) { @@ -80,7 +104,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } -func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { +func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.ReceiveFunc { return func(buff []byte) (int, conn.Endpoint, error) { n, endpoint, err := c.ReadFrom(buff) if err != nil { @@ -99,7 +123,14 @@ func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { }), nil } - err = b.iceMux.HandlePacket(buff, n, endpoint) + b.mu.Lock() + if _, ok := b.endpointMap[e.String()]; !ok { + b.endpointMap[e.String()] = c + log.Infof("added %s endpoint %s", bindMux.Type(), e.Addr().String()) + } + b.mu.Unlock() + + err = bindMux.HandlePacket(buff, n, endpoint) if err != nil { return 0, nil, err } @@ -116,23 +147,43 @@ func (b *ICEBind) Close() error { b.mu.Lock() defer b.mu.Unlock() - var err1, err2 error + var err1, err2, err3, err4 error if b.sharedConn != nil { c := b.sharedConn b.sharedConn = nil err1 = c.Close() } - - if b.iceMux != nil { - m := b.iceMux - b.iceMux = nil - err2 = m.Close() + if b.sharedConnHost != nil { + c := b.sharedConnHost + b.sharedConnHost = nil + err2 = c.Close() } + if b.iceSrflxMux != nil { + m := b.iceSrflxMux + b.iceSrflxMux = nil + err3 = m.Close() + } + + if b.iceHostMux != nil { + m := b.iceHostMux + b.iceHostMux = nil + err4 = m.Close() + } + + //todo close iceSrflxMux + if err1 != nil { return err1 } - return err2 + if err2 != nil { + return err2 + } + if err3 != nil { + return err3 + } + + return err4 } // SetMark sets the mark for each packet sent through this Bind. @@ -146,7 +197,17 @@ func (b *ICEBind) Send(buff []byte, endpoint conn.Endpoint) error { if !ok { return conn.ErrWrongEndpointType } - _, err := b.sharedConn.WriteTo(buff, (*net.UDPAddr)(nend)) + + b.mu.Lock() + co := b.endpointMap[(*net.UDPAddr)(nend).String()] + b.mu.Unlock() + if co == nil { + // todo proper handling + log.Warnf("conn not found for endpoint %s", endpoint.DstToString()) + return nil + } + + _, err := co.WriteTo(buff, (*net.UDPAddr)(nend)) return err } diff --git a/iface/udp_mux.go b/iface/udp_mux.go index ba917aeb4..040a216e0 100644 --- a/iface/udp_mux.go +++ b/iface/udp_mux.go @@ -68,6 +68,10 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } } +func (m *UDPMuxDefault) Type() string { + return "HOST" +} + func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error { udpAddr, ok := addr.(*net.UDPAddr) diff --git a/iface/udp_mux_universal.go b/iface/udp_mux_universal.go index e2ff55f68..d33bb4171 100644 --- a/iface/udp_mux_universal.go +++ b/iface/udp_mux_universal.go @@ -75,6 +75,10 @@ func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string) (net.Pa return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url)) } +func (m *UniversalUDPMuxDefault) Type() string { + return "SRFLX" +} + func (m *UniversalUDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error { if stun.IsMessage(p[:n]) { msg := &stun.Message{