diff --git a/client/iface/configurer/kernel_unix.go b/client/iface/configurer/kernel_unix.go index 7c1c41669..6f09a63c9 100644 --- a/client/iface/configurer/kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error return nil } -func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - // parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - +func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli PublicKey: peerKeyParsed, ReplaceAllowedIPs: false, // don't replace allowed ips, wg will handle duplicated peer IP - AllowedIPs: []net.IPNet{*ipNet}, + AllowedIPs: allowedIps, PersistentKeepaliveInterval: &keepAlive, Endpoint: endpoint, PresharedKey: preSharedKey, diff --git a/client/iface/configurer/usp.go b/client/iface/configurer/usp.go index 391269dd0..a3de58c24 100644 --- a/client/iface/configurer/usp.go +++ b/client/iface/configurer/usp.go @@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { - // parse allowed ips - _, ipNet, err := net.ParseCIDR(allowedIps) - if err != nil { - return err - } - +func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv PublicKey: peerKeyParsed, ReplaceAllowedIPs: false, // don't replace allowed ips, wg will handle duplicated peer IP - AllowedIPs: []net.IPNet{*ipNet}, + AllowedIPs: allowedIps, PersistentKeepaliveInterval: &keepAlive, PresharedKey: preSharedKey, Endpoint: endpoint, diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go index 0196b0085..6971b6946 100644 --- a/client/iface/device/interface.go +++ b/client/iface/device/interface.go @@ -11,7 +11,7 @@ import ( type WGConfigurer interface { ConfigureInterface(privateKey string, port int) error - UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error diff --git a/client/iface/iface.go b/client/iface/iface.go index 8056dd9a6..40bd51fbb 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -3,6 +3,7 @@ package iface import ( "fmt" "net" + "net/netip" "sync" "time" @@ -112,12 +113,13 @@ func (w *WGIface) UpdateAddr(newAddr string) error { // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // Endpoint is optional -func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { w.mu.Lock() defer w.mu.Unlock() + netIPNets := prefixesToIPNets(allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) - return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -250,3 +252,14 @@ func (w *WGIface) GetNet() *netstack.Net { return w.tun.GetNet() } + +func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet { + ipNets := make([]net.IPNet, len(prefixes)) + for i, prefix := range prefixes { + ipNets[i] = net.IPNet{ + IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP + Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask + } + } + return ipNets +} diff --git a/client/iface/iface_test.go b/client/iface/iface_test.go index 85db9cacb..e890b30f3 100644 --- a/client/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -373,12 +373,12 @@ func Test_UpdatePeer(t *testing.T) { t.Fatal(err) } keepAlive := 15 * time.Second - allowedIP := "10.99.99.10/32" + allowedIP := netip.MustParsePrefix("10.99.99.10/32") endpoint, err := net.ResolveUDPAddr("udp", "127.0.0.1:9900") if err != nil { t.Fatal(err) } - err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, endpoint, nil) + err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, endpoint, nil) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func Test_UpdatePeer(t *testing.T) { var foundAllowedIP bool for _, aip := range peer.AllowedIPs { - if aip.String() == allowedIP { + if aip.String() == allowedIP.String() { foundAllowedIP = true break } @@ -443,9 +443,8 @@ func Test_RemovePeer(t *testing.T) { t.Fatal(err) } keepAlive := 15 * time.Second - allowedIP := "10.99.99.14/32" - - err = iface.UpdatePeer(peerPubKey, allowedIP, keepAlive, nil, nil) + allowedIP := netip.MustParsePrefix("10.99.99.14/32") + err = iface.UpdatePeer(peerPubKey, []netip.Prefix{allowedIP}, keepAlive, nil, nil) if err != nil { t.Fatal(err) } @@ -462,12 +461,12 @@ func Test_RemovePeer(t *testing.T) { func Test_ConnectPeers(t *testing.T) { peer1ifaceName := fmt.Sprintf("utun%d", WgIntNumber+400) - peer1wgIP := "10.99.99.17/30" + peer1wgIP := netip.MustParsePrefix("10.99.99.17/30") peer1Key, _ := wgtypes.GeneratePrivateKey() peer1wgPort := 33100 peer2ifaceName := "utun500" - peer2wgIP := "10.99.99.18/30" + peer2wgIP := netip.MustParsePrefix("10.99.99.18/30") peer2Key, _ := wgtypes.GeneratePrivateKey() peer2wgPort := 33200 @@ -482,7 +481,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer1 := WGIFaceOpts{ IFaceName: peer1ifaceName, - Address: peer1wgIP, + Address: peer1wgIP.String(), WGPort: peer1wgPort, WGPrivKey: peer1Key.String(), MTU: DefaultMTU, @@ -522,7 +521,7 @@ func Test_ConnectPeers(t *testing.T) { optsPeer2 := WGIFaceOpts{ IFaceName: peer2ifaceName, - Address: peer2wgIP, + Address: peer2wgIP.String(), WGPort: peer2wgPort, WGPrivKey: peer2Key.String(), MTU: DefaultMTU, @@ -558,11 +557,11 @@ func Test_ConnectPeers(t *testing.T) { } }() - err = iface1.UpdatePeer(peer2Key.PublicKey().String(), peer2wgIP, keepAlive, peer2endpoint, nil) + err = iface1.UpdatePeer(peer2Key.PublicKey().String(), []netip.Prefix{peer2wgIP}, keepAlive, peer2endpoint, nil) if err != nil { t.Fatal(err) } - err = iface2.UpdatePeer(peer1Key.PublicKey().String(), peer1wgIP, keepAlive, peer1endpoint, nil) + err = iface2.UpdatePeer(peer1Key.PublicKey().String(), []netip.Prefix{peer1wgIP}, keepAlive, peer1endpoint, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index c939240d9..10c4fb970 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -527,15 +527,18 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { - if allowedIPs != strings.Join(p.AllowedIps, ",") { - modified = append(modified, p) - continue - } - err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) - if err != nil { - log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) - } + allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey) + if !ok { + continue + } + if !compareNetIPLists(allowedIPs, p.GetAllowedIps()) { + modified = append(modified, p) + continue + } + + err := e.statusRecorder.UpdatePeerFQDN(peerPubKey, p.GetFqdn()) + if err != nil { + log.Warnf("error updating peer's %s fqdn in the status recorder, got error: %v", peerPubKey, err) } } @@ -1103,34 +1106,45 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // addNewPeer add peer if connection doesn't exist func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() - peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerStore.PeerConn(peerKey); !ok { - conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) - if err != nil { - return fmt.Errorf("create peer connection: %w", err) - } - - if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { - conn.Close() - return fmt.Errorf("peer already exists: %s", peerKey) - } - - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } - - err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) - if err != nil { - log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) - } - - conn.Open() + peerIPs := make([]netip.Prefix, 0, len(peerConfig.GetAllowedIps())) + if _, ok := e.peerStore.PeerConn(peerKey); ok { + return nil } + + for _, ipString := range peerConfig.GetAllowedIps() { + allowedNetIP, err := netip.ParsePrefix(ipString) + if err != nil { + log.Errorf("failed to parse allowedIPS: %v", err) + return err + } + peerIPs = append(peerIPs, allowedNetIP) + } + + conn, err := e.createPeerConn(peerKey, peerIPs) + if err != nil { + return fmt.Errorf("create peer connection: %w", err) + } + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } + + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) + if err != nil { + log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) + } + + conn.Open() return nil } -func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) { +func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer.Conn, error) { log.Debugf("creating peer connection %s", pubKey) wgConfig := peer.WgConfig{ @@ -1815,3 +1829,36 @@ func getInterfacePrefixes() ([]netip.Prefix, error) { return prefixes, nberrors.FormatErrorOrNil(merr) } + +// compareNetIPLists compares a list of netip.Prefix with a list of strings. +// return true if both lists are equal, false otherwise. +func compareNetIPLists(list1 []netip.Prefix, list2 []string) bool { + if len(list1) != len(list2) { + return false + } + + freq := make(map[string]int, len(list1)) + for _, p := range list1 { + freq[p.String()]++ + } + + for _, s := range list2 { + p, err := netip.ParsePrefix(s) + if err != nil { + return false // invalid prefix in list2. + } + key := p.String() + if freq[key] == 0 { + return false + } + freq[key]-- + } + + // all counts should be zero if lists are equal. + for _, count := range freq { + if count != 0 { + return false + } + } + return true +} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 02c8edea7..54a347e31 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -26,6 +26,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -77,7 +78,7 @@ type MockWGIface struct { ToInterfaceFunc func() *net.Interface UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error - UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeerFunc func(peerKey string) error AddAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP string) error @@ -128,7 +129,7 @@ func (m *MockWGIface) UpdateAddr(newAddr string) error { return m.UpdateAddrFunc(newAddr) } -func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } @@ -534,7 +535,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { t.Errorf("expecting Engine.peerConns to contain peer %s", p) } expectedAllowedIPs := strings.Join(p.AllowedIps, ",") - if conn.WgConfig().AllowedIps != expectedAllowedIPs { + if !compareNetIPLists(conn.WgConfig().AllowedIps, p.AllowedIps) { t.Errorf("expecting peer %s to have AllowedIPs= %s, got %s", p.GetWgPubKey(), expectedAllowedIPs, conn.WgConfig().AllowedIps) } @@ -1237,6 +1238,91 @@ func Test_CheckFilesEqual(t *testing.T) { } } +func TestCompareNetIPLists(t *testing.T) { + tests := []struct { + name string + list1 []netip.Prefix + list2 []string + expected bool + }{ + { + name: "both empty", + list1: []netip.Prefix{}, + list2: []string{}, + expected: true, + }, + { + name: "single match ipv4", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + list2: []string{"192.168.0.0/24"}, + expected: true, + }, + { + name: "multiple match ipv4, different order", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("10.0.0.0/8")}, + list2: []string{"10.0.0.0/8", "192.168.1.0/24"}, + expected: true, + }, + { + name: "ipv4 mismatch due to extra element in list2", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"192.168.1.0/24", "10.0.0.0/8"}, + expected: false, + }, + { + name: "ipv4 mismatch due to duplicate count", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"192.168.1.0/24"}, + expected: false, + }, + { + name: "invalid prefix in list2", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"invalid-prefix"}, + expected: false, + }, + { + name: "ipv4 mismatch because different prefixes", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + list2: []string{"10.0.0.0/8"}, + expected: false, + }, + { + name: "single match ipv6", + list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")}, + list2: []string{"2001:db8::/32"}, + expected: true, + }, + { + name: "multiple match ipv6, different order", + list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32"), netip.MustParsePrefix("fe80::/10")}, + list2: []string{"fe80::/10", "2001:db8::/32"}, + expected: true, + }, + { + name: "mixed ipv4 and ipv6 match", + list1: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("2001:db8::/32")}, + list2: []string{"2001:db8::/32", "192.168.1.0/24"}, + expected: true, + }, + { + name: "ipv6 mismatch with invalid prefix", + list1: []netip.Prefix{netip.MustParsePrefix("2001:db8::/32")}, + list2: []string{"invalid-ipv6"}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := compareNetIPLists(tt.list1, tt.list2) + if result != tt.expected { + t.Errorf("compareNetIPLists(%v, %v) = %v; want %v", tt.list1, tt.list2, result, tt.expected) + } + }) + } +} + func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index a66342707..65b425015 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -2,6 +2,7 @@ package internal import ( "net" + "net/netip" "time" wgdevice "golang.zx2c4.com/wireguard/device" @@ -24,7 +25,7 @@ type wgIfaceBase interface { Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy - UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error AddAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0337960bb..9b4d1a554 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -5,9 +5,9 @@ import ( "fmt" "math/rand" "net" + "net/netip" "os" "runtime" - "strings" "sync" "time" @@ -56,7 +56,7 @@ type WgConfig struct { WgListenPort int RemoteKey string WgInterface WGIface - AllowedIps string + AllowedIps []netip.Prefix PreSharedKey *wgtypes.Key } @@ -91,11 +91,10 @@ type Conn struct { statusRecorder *Status signaler *Signaler relayManager *relayClient.Manager - allowedIP net.IP handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) - onDisconnected func(remotePeer string, wgIP string) + onDisconnected func(remotePeer string) statusRelay *AtomicConnStatus statusICE *AtomicConnStatus @@ -120,10 +119,8 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { - allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) - if err != nil { - log.Errorf("failed to parse allowedIPS: %v", err) - return nil, err + if len(config.WgConfig.AllowedIps) == 0 { + return nil, fmt.Errorf("allowed IPs is empty") } ctx, ctxCancel := context.WithCancel(engineCtx) @@ -137,7 +134,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu statusRecorder: statusRecorder, signaler: signaler, relayManager: relayManager, - allowedIP: allowedIP, statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), semaphore: semaphore, @@ -147,10 +143,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager) relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - conn.workerICE, err = NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) + workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) if err != nil { return nil, err } + conn.workerICE = workerICE conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay) @@ -179,7 +176,7 @@ func (conn *Conn) Open() { peerState := State{ PubKey: conn.config.Key, - IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], + IP: conn.config.WgConfig.AllowedIps[0].Addr().String(), ConnStatusUpdate: time.Now(), ConnStatus: StatusDisconnected, Mux: new(sync.RWMutex), @@ -245,7 +242,7 @@ func (conn *Conn) Close() { conn.freeUpConnID() if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { - conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) + conn.onDisconnected(conn.config.WgConfig.RemoteKey) } conn.setStatusToDisconnected() @@ -276,7 +273,7 @@ func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteR } // SetOnDisconnected sets a handler function to be triggered by Conn when a connection to a remote disconnected -func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)) { +func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) { conn.onDisconnected = handler } @@ -601,7 +598,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.config.WgConfig.AllowedIps[0].Addr().String(), remoteRosenpassAddr) } } @@ -698,7 +695,7 @@ func (conn *Conn) freeUpConnID() { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.log.Debugf("setup proxied WireGuard connection") udpAddr := &net.UDPAddr{ - IP: conn.allowedIP, + IP: conn.config.WgConfig.AllowedIps[0].Addr().AsSlice(), Port: conn.config.WgConfig.WgListenPort, } @@ -752,8 +749,8 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { } // AllowedIP returns the allowed IP of the remote peer -func (conn *Conn) AllowedIP() net.IP { - return conn.allowedIP +func (conn *Conn) AllowedIP() netip.Addr { + return conn.config.WgConfig.AllowedIps[0].Addr() } func isController(config ConnConfig) bool { diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go index ae6b3bd0a..c7b6de9ea 100644 --- a/client/internal/peer/iface.go +++ b/client/internal/peer/iface.go @@ -2,15 +2,17 @@ package peer import ( "net" + "net/netip" "time" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type WGIface interface { - UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemovePeer(peerKey string) error GetStats(peerKey string) (configurer.WGStats, error) GetProxy() wgproxy.Proxy diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go index 6b3385ff5..15d34d3d0 100644 --- a/client/internal/peerstore/store.go +++ b/client/internal/peerstore/store.go @@ -1,7 +1,7 @@ package peerstore import ( - "net" + "net/netip" "sync" "golang.org/x/exp/maps" @@ -46,18 +46,7 @@ func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { return p, true } -func (s *Store) AllowedIPs(pubKey string) (string, bool) { - s.peerConnsMu.RLock() - defer s.peerConnsMu.RUnlock() - - p, ok := s.peerConns[pubKey] - if !ok { - return "", false - } - return p.WgConfig().AllowedIps, true -} - -func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { +func (s *Store) AllowedIPs(pubKey string) ([]netip.Prefix, bool) { s.peerConnsMu.RLock() defer s.peerConnsMu.RUnlock() @@ -65,6 +54,17 @@ func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { if !ok { return nil, false } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (netip.Addr, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return netip.Addr{}, false + } return p.AllowedIP(), true } diff --git a/client/internal/rosenpass/manager.go b/client/internal/rosenpass/manager.go index bf019453b..d2d7408fd 100644 --- a/client/internal/rosenpass/manager.go +++ b/client/internal/rosenpass/manager.go @@ -126,7 +126,7 @@ func (m *Manager) generateConfig() (rp.Config, error) { return cfg, nil } -func (m *Manager) OnDisconnected(peerKey string, wgIP string) { +func (m *Manager) OnDisconnected(peerKey string) { m.lock.Lock() defer m.lock.Unlock() diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 10cb03f1d..f36285cc4 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -3,7 +3,6 @@ package dnsinterceptor import ( "context" "fmt" - "net" "net/netip" "strings" "sync" @@ -165,14 +164,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { Timeout: 5 * time.Second, Net: "udp", } - upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) reply, _, err := client.ExchangeContext(context.Background(), r, upstream) var answer []dns.RR if reply != nil { answer = reply.Answer } - log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) if err != nil { log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) @@ -201,10 +200,10 @@ func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, } } -func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) { peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) if !exists { - return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) + return netip.Addr{}, fmt.Errorf("peer connection not found for key: %s", peerKey) } return peerAllowedIP, nil }