Process routes before peers (#2105)

This commit is contained in:
Viktor Liu
2024-06-19 12:12:11 +02:00
committed by GitHub
parent b679404618
commit 61bc092458
22 changed files with 165 additions and 93 deletions

View File

@@ -28,11 +28,14 @@ type ICEBind struct {
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
filterFn FilterFn
}
func NewICEBind(transportNet transport.Net) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}
rc := receiverCreator{
@@ -59,8 +62,9 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {

View File

@@ -8,6 +8,8 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -17,6 +19,10 @@ import (
"github.com/pion/transport/v3"
)
// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
@@ -34,6 +40,7 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
}
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -56,6 +63,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
}
// embed UDPMux
@@ -105,8 +113,68 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}
return u.handleUncachedAddress(b, addr)
}
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}
a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}
func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
}
// GetSharedConn returns the shared udp conn

View File

@@ -4,17 +4,19 @@ import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter),
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil

View File

@@ -7,11 +7,12 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
@@ -22,11 +23,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
}
if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr())
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}

View File

@@ -6,16 +6,18 @@ import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd),
tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil

View File

@@ -41,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil)
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -114,7 +114,7 @@ func Test_CreateInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil)
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -149,7 +149,7 @@ func Test_Close(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil)
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -182,7 +182,7 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil)
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -230,7 +230,7 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil)
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -291,7 +291,7 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err)
}
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil)
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -345,7 +345,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -368,7 +368,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil {
t.Fatal(err)
}
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil)
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -8,11 +8,12 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
@@ -22,7 +23,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr())
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
@@ -36,7 +37,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
if !tunModuleIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}

View File

@@ -5,11 +5,12 @@ import (
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) {
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
@@ -20,11 +21,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
}
if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr())
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}

View File

@@ -31,13 +31,13 @@ type wgTunDevice struct {
configurer wgConfigurer
}
func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter) wgTunDevice {
func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice {
return wgTunDevice{
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
iceBind: bind.NewICEBind(transportNet, filterFn),
tunAdapter: tunAdapter,
}
}

View File

@@ -27,14 +27,14 @@ type tunDevice struct {
configurer wgConfigurer
}
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
return &tunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}

View File

@@ -29,13 +29,13 @@ type tunDevice struct {
configurer wgConfigurer
}
func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int) *tunDevice {
func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice {
return &tunDevice{
name: name,
address: address,
port: port,
key: key,
iceBind: bind.NewICEBind(transportNet),
iceBind: bind.NewICEBind(transportNet, filterFn),
tunFd: tunFd,
}
}

View File

@@ -27,6 +27,8 @@ type tunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault
filterFn bind.FilterFn
}
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
@@ -96,8 +98,9 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)

View File

@@ -30,7 +30,7 @@ type tunNetstackDevice struct {
configurer wgConfigurer
}
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice {
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice {
return &tunNetstackDevice{
name: name,
address: address,
@@ -38,7 +38,7 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet),
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}

View File

@@ -29,7 +29,7 @@ type tunUSPDevice struct {
configurer wgConfigurer
}
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
log.Infof("using userspace bind mode")
checkUser()
@@ -40,7 +40,7 @@ func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu i
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}

View File

@@ -29,14 +29,14 @@ type tunDevice struct {
configurer wgConfigurer
}
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice {
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
return &tunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet),
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}