diff --git a/iface/bind/bind.go b/iface/bind/bind.go index 9faf9c81b..4f02bf076 100644 --- a/iface/bind/bind.go +++ b/iface/bind/bind.go @@ -78,9 +78,19 @@ func NewICEBind(transportNet transport.Net) *ICEBind { } } +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if supported. + src struct { + netip.Addr + ifidx int32 + } +} + var ( _ wgConn.Bind = (*ICEBind)(nil) - _ wgConn.Endpoint = &wgConn.StdNetEndpoint{} + _ wgConn.Endpoint = &StdNetEndpoint{} ) func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { @@ -88,6 +98,36 @@ func (*ICEBind) ParseEndpoint(s string) (wgConn.Endpoint, error) { return asEndpoint(e), err } +func (e *StdNetEndpoint) ClearSrc() { + e.src.ifidx = 0 + e.src.Addr = netip.Addr{} +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return e.src.Addr +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return e.src.ifidx +} + +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.src.Addr.String() +} + func listenNet(network string, port int) (*net.UDPConn, int, error) { conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { @@ -275,12 +315,12 @@ func (s *ICEBind) send4(conn *ipv4.PacketConn, ep wgConn.Endpoint, buffs [][]byt as4 := ep.DstIP().As4() copy(ua.IP, as4[:]) ua.IP = ua.IP[:4] - ua.Port = int(ep.(*wgConn.StdNetEndpoint).Port()) + ua.Port = int(ep.(*StdNetEndpoint).Port()) msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message) for i, buff := range buffs { (*msgs)[i].Buffers[0] = buff (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*wgConn.StdNetEndpoint)) + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) } var ( n int @@ -304,12 +344,12 @@ func (s *ICEBind) send6(conn *ipv6.PacketConn, ep wgConn.Endpoint, buffs [][]byt as16 := ep.DstIP().As16() copy(ua.IP, as16[:]) ua.IP = ua.IP[:16] - ua.Port = int(ep.(*wgConn.StdNetEndpoint).Port()) + ua.Port = int(ep.(*StdNetEndpoint).Port()) msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message) for i, buff := range buffs { (*msgs)[i].Buffers[0] = buff (*msgs)[i].Addr = ua - setSrcControl(&(*msgs)[i].OOB, ep.(*wgConn.StdNetEndpoint)) + setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint)) } var ( n int @@ -354,17 +394,17 @@ func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) // but Endpoints are immutable, so we can re-use them. var endpointPool = sync.Pool{ New: func() any { - return make(map[netip.AddrPort]*wgConn.StdNetEndpoint) + return make(map[netip.AddrPort]*StdNetEndpoint) }, } // asEndpoint returns an Endpoint containing ap. -func asEndpoint(ap netip.AddrPort) *wgConn.StdNetEndpoint { - m := endpointPool.Get().(map[netip.AddrPort]*wgConn.StdNetEndpoint) +func asEndpoint(ap netip.AddrPort) *StdNetEndpoint { + m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint) defer endpointPool.Put(m) e, ok := m[ap] if !ok { - e = &wgConn.StdNetEndpoint{AddrPort: ap} + e = &StdNetEndpoint{AddrPort: ap} m[ap] = e } return e