From eaf985624d9ebe979eb2ba05f2c7b80314410f82 Mon Sep 17 00:00:00 2001 From: braginini Date: Wed, 7 Sep 2022 18:39:58 +0200 Subject: [PATCH] Single Mux --- client/internal/engine.go | 6 +-- iface/bind.go | 93 ++++++++++++++++++--------------------- iface/udp_mux.go | 77 +++++++++++++++++++++----------- 3 files changed, 99 insertions(+), 77 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 38a2abcff..9ba1be5c0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -210,11 +210,11 @@ func (e *Engine) Start() error { } e.iceMux = iceMux - iceHostMux, err := bind.GetICEHostMux() + /*iceHostMux, err := bind.GetICEHostMux() if err != nil { return err - } - e.iceHostMux = iceHostMux + }*/ + e.iceHostMux = iceMux log.Infof("NetBird Engine started listening on WireGuard port %d", *port) diff --git a/iface/bind.go b/iface/bind.go index a63b8d678..7520ff36e 100644 --- a/iface/bind.go +++ b/iface/bind.go @@ -18,10 +18,9 @@ type BindMux interface { } type ICEBind struct { - sharedConn net.PacketConn - sharedConnHost net.PacketConn - iceSrflxMux *UniversalUDPMuxDefault - iceHostMux *UDPMuxDefault + sharedConn net.PacketConn + udpMux *UniversalUDPMuxDefault + iceHostMux *UDPMuxDefault endpointMap map[string]net.PacketConn @@ -31,11 +30,11 @@ type ICEBind struct { func (b *ICEBind) GetICEMux() (UniversalUDPMux, error) { b.mu.Lock() defer b.mu.Unlock() - if b.iceSrflxMux == nil { + if b.udpMux == nil { return nil, fmt.Errorf("ICEBind has not been initialized yet") } - return b.iceSrflxMux, nil + return b.udpMux, nil } func (b *ICEBind) GetICEHostMux() (UDPMux, error) { @@ -55,9 +54,6 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { if b.sharedConn != nil { return nil, 0, conn.ErrBindAlreadyOpen } - if b.sharedConnHost != nil { - return nil, 0, conn.ErrBindAlreadyOpen - } b.endpointMap = make(map[string]net.PacketConn) @@ -66,24 +62,20 @@ func (b *ICEBind) Open(uport uint16) ([]conn.ReceiveFunc, uint16, error) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } - ipv4ConnHost, port, err := listenNet("udp4", 0) - if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, 0, err - } b.sharedConn = ipv4Conn - b.sharedConnHost = ipv4ConnHost - b.iceSrflxMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn}) - b.iceHostMux = NewUDPMuxDefault(UDPMuxParams{UDPConn: b.sharedConnHost}) + b.udpMux = NewUniversalUDPMuxDefault(UniversalUDPMuxParams{UDPConn: b.sharedConn}) - portAddr, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String()) + portAddr1, err := netip.ParseAddrPort(ipv4Conn.LocalAddr().String()) if err != nil { return nil, 0, err } + + log.Infof("opened ICEBind on %s", ipv4Conn.LocalAddr().String()) + return []conn.ReceiveFunc{ - b.makeReceiveIPv4(b.sharedConn, b.iceSrflxMux), - b.makeReceiveIPv4(b.sharedConnHost, b.iceHostMux), + b.makeReceiveIPv4(b.sharedConn), }, - portAddr.Port(), nil + portAddr1.Port(), nil } func listenNet(network string, port int) (*net.UDPConn, int, error) { @@ -104,7 +96,7 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } -func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.ReceiveFunc { +func (b *ICEBind) makeReceiveIPv4(c net.PacketConn) conn.ReceiveFunc { return func(buff []byte) (int, conn.Endpoint, error) { n, endpoint, err := c.ReadFrom(buff) if err != nil { @@ -122,15 +114,37 @@ func (b *ICEBind) makeReceiveIPv4(c net.PacketConn, bindMux BindMux) conn.Receiv Zone: e.Addr().Zone(), }), nil } - b.mu.Lock() + + /* msg := &stun.Message{ + Raw: append([]byte{}, buff[:n]...), + } + if err := msg.Decode(); err != nil { + return 0, nil, err + } + strAttrs := []string{} + for _, attribute := range msg.Attributes { + strAttrs = append(strAttrs, attribute.String()) + } + + xorMapped := "EMPTY" + _, err = msg.Get(stun.AttrXORMappedAddress) + if err == nil { + var addr stun.XORMappedAddress + if err := addr.GetFrom(msg); err == nil { + xorMapped = addr.String() + } + } + + log.Printf("endpoint %s XORMAPPED %s mux type %s msg type %s, attributes %s", endpoint.String(), xorMapped, bindMux.Type(), msg.Type.String(), strings.Join(strAttrs[:], ";")) + */ if _, ok := b.endpointMap[e.String()]; !ok { b.endpointMap[e.String()] = c - log.Infof("added %s endpoint %s", bindMux.Type(), e.String()) + log.Infof("added endpoint %s", e.String()) } b.mu.Unlock() - err = bindMux.HandlePacket(buff, n, endpoint) + err = b.udpMux.HandlePacket(buff, n, endpoint) if err != nil { return 0, nil, err } @@ -147,43 +161,24 @@ func (b *ICEBind) Close() error { b.mu.Lock() defer b.mu.Unlock() - var err1, err2, err3, err4 error + var err1, err2 error if b.sharedConn != nil { c := b.sharedConn b.sharedConn = nil err1 = c.Close() } - if b.sharedConnHost != nil { - c := b.sharedConnHost - b.sharedConnHost = nil - err2 = c.Close() - } - if b.iceSrflxMux != nil { - m := b.iceSrflxMux - b.iceSrflxMux = nil - err3 = m.Close() + if b.udpMux != nil { + m := b.udpMux + b.udpMux = nil + err2 = m.Close() } - if b.iceHostMux != nil { - m := b.iceHostMux - b.iceHostMux = nil - err4 = m.Close() - } - - //todo close iceSrflxMux - if err1 != nil { return err1 } - if err2 != nil { - return err2 - } - if err3 != nil { - return err3 - } - return err4 + return err2 } // SetMark sets the mark for each packet sent through this Bind. diff --git a/iface/udp_mux.go b/iface/udp_mux.go index da44d521d..112331fe1 100644 --- a/iface/udp_mux.go +++ b/iface/udp_mux.go @@ -32,7 +32,7 @@ type UDPMuxDefault struct { conns map[string]*udpMuxedConn addressMapMu sync.RWMutex - addressMap map[string]*udpMuxedConn + addressMap map[string][]*udpMuxedConn // buffer pool to recycle buffers for net.UDPAddr encodes/decodes pool *sync.Pool @@ -55,7 +55,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { } return &UDPMuxDefault{ - addressMap: map[string]*udpMuxedConn{}, + addressMap: map[string][]*udpMuxedConn{}, params: params, conns: make(map[string]*udpMuxedConn), closedChan: make(chan struct{}, 1), @@ -81,11 +81,19 @@ func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error { // If we have already seen this address dispatch to the appropriate destination m.addressMapMu.Lock() - destinationConn := m.addressMap[addr.String()] + var destinationConnList []*udpMuxedConn + if storedConns, ok := m.addressMap[addr.String()]; ok { + for _, conn := range storedConns { + destinationConnList = append(destinationConnList, conn) + } + } m.addressMapMu.Unlock() // If we haven't seen this address before but is a STUN packet lookup by ufrag - if destinationConn == nil && stun.IsMessage(p[:20]) { + if stun.IsMessage(p[:20]) { + // This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. + // However, we can take a username attribute from the STUN message which contains ufrag. + // We can use ufrag to identify the destination conn to route packet to. msg := &stun.Message{ Raw: append([]byte{}, p[:n]...), } @@ -96,25 +104,32 @@ func (m *UDPMuxDefault) HandlePacket(p []byte, n int, addr net.Addr) error { } attr, stunAttrErr := msg.Get(stun.AttrUsername) - if stunAttrErr != nil { - log.Warnf("No Username attribute in STUN message from %s\n", addr.String()) - return stunAttrErr + if stunAttrErr == nil { + ufrag := strings.Split(string(attr), ":")[0] + + m.mu.Lock() + if destinationConn, ok := m.conns[ufrag]; ok { + exists := false + for _, conn := range destinationConnList { + if conn.params.Key == destinationConn.params.Key { + exists = true + break + } + } + if !exists { + destinationConnList = append(destinationConnList, destinationConn) + } + } + m.mu.Unlock() + } else { + //log.Warnf("No Username attribute in STUN message from %s\n", addr.String()) } - - ufrag := strings.Split(string(attr), ":")[0] - - m.mu.Lock() - destinationConn = m.conns[ufrag] - m.mu.Unlock() } - if destinationConn == nil { - log.Tracef("dropping packet from %s, addr: %s", udpAddr.String(), addr.String()) - return nil - } - - if err := destinationConn.writePacket(p[:n], udpAddr); err != nil { - log.Errorf("could not write packet: %v", err) + for _, conn := range destinationConnList { + if err := conn.writePacket(p[:n], udpAddr); err != nil { + log.Errorf("could not write packet: %v", err) + } } return nil @@ -131,6 +146,8 @@ func (m *UDPMuxDefault) GetConn(ufrag string) (net.PacketConn, error) { m.mu.Lock() defer m.mu.Unlock() + log.Debugf("ICE %s: getting muxed connection for %s", m.Type(), ufrag) + if m.IsClosed() { return nil, io.ErrClosedPipe } @@ -219,7 +236,15 @@ func (m *UDPMuxDefault) removeConn(key string) { addresses := c.getAddresses() for _, addr := range addresses { - delete(m.addressMap, addr) + if connList, ok := m.addressMap[addr]; ok { + var newList []*udpMuxedConn + for _, conn := range connList { + if conn.params.Key != key { + newList = append(newList, conn) + } + } + m.addressMap[addr] = newList + } } } @@ -236,12 +261,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) defer m.addressMapMu.Unlock() existing, ok := m.addressMap[addr] - if ok { - existing.removeAddress(addr) + if !ok { + existing = []*udpMuxedConn{} } - m.addressMap[addr] = conn + existing = append(existing, conn) + m.addressMap[addr] = existing - m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key) + log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { @@ -252,6 +278,7 @@ func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { LocalAddr: m.LocalAddr(), Logger: m.params.Logger, }) + log.Debugf("ICE: created muxed connection %s for %s", c.LocalAddr().String(), key) return c }