diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index f8afd3d6e..d7007c860 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -32,6 +32,9 @@ jobs: restore-keys: | macos-go- + - name: Install libpcap + run: brew install libpcap + - name: Install modules run: go mod tidy diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 6027d3626..2d63acbcd 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -46,7 +46,7 @@ jobs: - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt diff --git a/README.md b/README.md index 9b52f5b5f..d2a2bd6b9 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,12 @@ **Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth. -**Secure.** NetBird enables secure remote access by applying granular access policies, while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. +**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. ### Open-Source Network Security in a Single Platform -![download (2)](https://github.com/netbirdio/netbird/assets/700848/16210ac2-7265-44c1-8d4e-8fae85534dac) +![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f) + ### Key features @@ -76,7 +77,7 @@ Follow the [Advanced guide with a custom identity provider](https://docs.netbird - **Public domain** name pointing to the VM. **Software requirements:** -- Docker installed on the VM with the docker compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. +- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher. - [jq](https://jqlang.github.io/jq/) installed. In most distributions Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq` - [curl](https://curl.se/) installed. @@ -93,9 +94,9 @@ export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbird - Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard. - Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers). - NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines. -- Connection candidates are discovered with a help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. +- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers. - Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates. -- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. +- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server. [Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups. @@ -119,7 +120,7 @@ In November 2022, NetBird joined the [StartUpSecure program](https://www.forschu ![CISPA_Logo_BLACK_EN_RZ_RGB (1)](https://user-images.githubusercontent.com/700848/203091324-c6d311a0-22b5-4b05-a288-91cbc6cdcc46.png) ### Testimonials -We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g. giving a star or a contribution). +We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution). ### Legal _WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld. diff --git a/client/cmd/root.go b/client/cmd/root.go index c3ff0a3c8..9c4ad99de 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -34,6 +34,7 @@ const ( wireguardPortFlag = "wireguard-port" disableAutoConnectFlag = "disable-auto-connect" serverSSHAllowedFlag = "allow-server-ssh" + extraIFaceBlackListFlag = "extra-iface-blacklist" ) var ( @@ -63,6 +64,7 @@ var ( wireguardPort uint16 serviceName string autoConnectDisabled bool + extraIFaceBlackList []string rootCmd = &cobra.Command{ Use: "netbird", Short: "", diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index da6beef4f..5e147262b 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -64,6 +64,10 @@ var installCmd = &cobra.Command{ } } + if runtime.GOOS == "windows" { + svcConfig.Option["OnFailure"] = "restart" + } + ctx, cancel := context.WithCancel(cmd.Context()) s, err := newSVC(newProgram(ctx, cancel), svcConfig) diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 2cfc93415..2f92e1c03 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc" + "github.com/netbirdio/management-integrations/integrations" clientProto "github.com/netbirdio/netbird/client/proto" client "github.com/netbirdio/netbird/client/server" mgmtProto "github.com/netbirdio/netbird/management/proto" @@ -78,7 +79,8 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { return nil, nil } - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + iv, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index f44f29a47..c2c3c7c90 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -40,6 +40,7 @@ func init() { upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") + upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") } func upFunc(cmd *cobra.Command, args []string) error { @@ -83,11 +84,12 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } ic := internal.ConfigInput{ - ManagementURL: managementURL, - AdminURL: adminURL, - ConfigPath: configPath, - NATExternalIPs: natExternalIPs, - CustomDNSAddress: customDNSAddressConverted, + ManagementURL: managementURL, + AdminURL: adminURL, + ConfigPath: configPath, + NATExternalIPs: natExternalIPs, + CustomDNSAddress: customDNSAddressConverted, + ExtraIFaceBlackList: extraIFaceBlackList, } if cmd.Flag(enableRosenpassFlag).Changed { @@ -149,7 +151,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { } func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { - customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed) if err != nil { return err @@ -190,6 +191,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { CustomDNSAddress: customDNSAddressConverted, IsLinuxDesktopClient: isLinuxRunningDesktop(), Hostname: hostName, + ExtraIFaceBlacklist: extraIFaceBlackList, } if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { diff --git a/client/internal/config.go b/client/internal/config.go index 2f6958235..5b3c61cbd 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -30,8 +30,10 @@ const ( DefaultAdminURL = "https://app.netbird.io:443" ) -var defaultInterfaceBlacklist = []string{iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", - "Tailscale", "tailscale", "docker", "veth", "br-", "lo"} +var defaultInterfaceBlacklist = []string{ + iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", + "Tailscale", "tailscale", "docker", "veth", "br-", "lo", +} // ConfigInput carries configuration changes to the client type ConfigInput struct { @@ -47,6 +49,7 @@ type ConfigInput struct { InterfaceName *string WireguardPort *int DisableAutoConnect *bool + ExtraIFaceBlackList []string } // Config Configuration type @@ -220,7 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) { config.AdminURL = newURL } - config.IFaceBlackList = defaultInterfaceBlacklist + // nolint:gocritic + config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...) return config, nil } @@ -320,6 +324,13 @@ func update(input ConfigInput) (*Config, error) { refresh = true } + if len(input.ExtraIFaceBlackList) > 0 { + for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) { + config.IFaceBlackList = append(config.IFaceBlackList, iFace) + refresh = true + } + } + if refresh { // since we have new management URL, we need to update config file if err := util.WriteJson(input.ConfigPath, config); err != nil { @@ -384,7 +395,6 @@ func configFileIsExists(path string) bool { // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) { - defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL) if err != nil { return nil, err diff --git a/client/internal/config_test.go b/client/internal/config_test.go index 7453c8fdf..978d0b3df 100644 --- a/client/internal/config_test.go +++ b/client/internal/config_test.go @@ -18,7 +18,6 @@ func TestGetConfig(t *testing.T) { config, err := UpdateOrCreateConfig(ConfigInput{ ConfigPath: filepath.Join(t.TempDir(), "config.json"), }) - if err != nil { return } @@ -86,6 +85,26 @@ func TestGetConfig(t *testing.T) { assert.Equal(t, readConf.(*Config).ManagementURL.String(), newManagementURL) } +func TestExtraIFaceBlackList(t *testing.T) { + extraIFaceBlackList := []string{"eth1"} + path := filepath.Join(t.TempDir(), "config.json") + config, err := UpdateOrCreateConfig(ConfigInput{ + ConfigPath: path, + ExtraIFaceBlackList: extraIFaceBlackList, + }) + if err != nil { + return + } + + assert.Contains(t, config.IFaceBlackList, "eth1") + readConf, err := util.ReadJson(path, config) + if err != nil { + return + } + + assert.Contains(t, readConf.(*Config).IFaceBlackList, "eth1") +} + func TestHiddenPreSharedKey(t *testing.T) { hidden := "**********" samplePreSharedKey := "mysecretpresharedkey" @@ -111,7 +130,6 @@ func TestHiddenPreSharedKey(t *testing.T) { ConfigPath: cfgFile, PreSharedKey: tt.preSharedKey, }) - if err != nil { t.Fatalf("failed to get cfg: %s", err) } diff --git a/client/internal/connect.go b/client/internal/connect.go index 682a1efed..6b888c9cc 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "runtime" + "runtime/debug" "strings" "time" @@ -93,7 +95,13 @@ func runClient( relayProbe *Probe, wgProbe *Probe, ) error { - log.Infof("starting NetBird client version %s", version.NetbirdVersion()) + defer func() { + if r := recover(); r != nil { + log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + } + }() + + log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) // Check if client was not shut down in a clean way and restore DNS config if required. // Otherwise, we might not be able to connect to the management server to retrieve new config. diff --git a/client/internal/engine.go b/client/internal/engine.go index 7f7b5ef55..d6238c4b3 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -93,6 +93,10 @@ type Engine struct { mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn + + beforePeerHook peer.BeforeAddPeerHookFunc + afterPeerHook peer.AfterRemovePeerHookFunc + // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -260,10 +264,14 @@ func (e *Engine) Start() error { e.dnsServer = dnsServer e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes) - if err := e.routeManager.Init(); err != nil { - e.close() - return fmt.Errorf("init route manager: %w", err) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() + if err != nil { + log.Errorf("Failed to initialize route manager: %s", err) + } else { + e.beforePeerHook = beforePeerHook + e.afterPeerHook = afterPeerHook } + e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) err = e.wgInterfaceCreate() @@ -809,10 +817,15 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { if _, ok := e.peerConns[peerKey]; !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { - return err + return fmt.Errorf("create peer connection: %w", err) } e.peerConns[peerKey] = conn + if e.beforePeerHook != nil && e.afterPeerHook != nil { + conn.AddBeforeAddPeerHook(e.beforePeerHook) + conn.AddAfterRemovePeerHook(e.afterPeerHook) + } + err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn) if err != nil { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) @@ -1106,6 +1119,10 @@ func (e *Engine) close() { e.dnsServer.Stop() } + if e.routeManager != nil { + e.routeManager.Stop() + } + log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { @@ -1120,10 +1137,6 @@ func (e *Engine) close() { } } - if e.routeManager != nil { - e.routeManager.Stop() - } - if e.firewall != nil { err := e.firewall.Reset() if err != nil { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 952b3c90c..309b2e7c6 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -21,6 +21,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -1050,7 +1051,8 @@ func startManagement(dataDir string) (*grpc.Server, string, error) { if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index c180e8f03..f3d07dcad 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -20,12 +20,15 @@ import ( "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" + nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" ) const ( iceKeepAliveDefault = 4 * time.Second iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second defaultWgKeepAlive = 25 * time.Second ) @@ -98,6 +101,9 @@ type IceCredentials struct { Pwd string } +type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error +type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error + type Conn struct { config ConnConfig mu sync.Mutex @@ -136,6 +142,10 @@ type Conn struct { remoteEndpoint *net.UDPAddr remoteConn *ice.Conn + + connID nbnet.ConnectionID + beforeAddPeerHooks []BeforeAddPeerHookFunc + afterRemovePeerHooks []AfterRemovePeerHookFunc } // meta holds meta information about a connection @@ -196,20 +206,22 @@ func (conn *Conn) reCreateAgent() error { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: conn.config.StunTurn, - CandidateTypes: conn.candidateTypes(), - FailedTimeout: &failedTimeout, - InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), - UDPMux: conn.config.UDPMux, - UDPMuxSrflx: conn.config.UDPMuxSrflx, - NAT1To1IPs: conn.config.NATExternalIPs, - Net: transportNet, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: conn.config.StunTurn, + CandidateTypes: conn.candidateTypes(), + FailedTimeout: &failedTimeout, + InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), + UDPMux: conn.config.UDPMux, + UDPMuxSrflx: conn.config.UDPMuxSrflx, + NAT1To1IPs: conn.config.NATExternalIPs, + Net: transportNet, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, } if conn.config.DisableIPv6Discovery { @@ -389,6 +401,14 @@ func isRelayCandidate(candidate ice.Candidate) bool { return candidate.Type() == ice.CandidateTypeRelay } +func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) { + conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) +} + +func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) { + conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) +} + // configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { conn.mu.Lock() @@ -415,6 +435,14 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) conn.remoteEndpoint = endpointUdpAddr + log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + + conn.connID = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { + log.Errorf("Before add peer hook failed: %v", err) + } + } err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) if err != nil { @@ -506,6 +534,15 @@ func (conn *Conn) cleanup() error { // todo: is it problem if we try to remove a peer what is never existed? err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if conn.connID != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connID); err != nil { + log.Errorf("After remove peer hook failed: %v", err) + } + } + } + conn.connID = "" + if conn.notifyDisconnected != nil { conn.notifyDisconnected() conn.notifyDisconnected = nil diff --git a/client/internal/peer/env_config.go b/client/internal/peer/env_config.go index 540bc413e..87b626df7 100644 --- a/client/internal/peer/env_config.go +++ b/client/internal/peer/env_config.go @@ -10,9 +10,10 @@ import ( ) const ( - envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" - envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" - envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" + envICEKeepAliveIntervalSec = "NB_ICE_KEEP_ALIVE_INTERVAL_SEC" + envICEDisconnectedTimeoutSec = "NB_ICE_DISCONNECTED_TIMEOUT_SEC" + envICERelayAcceptanceMinWaitSec = "NB_ICE_RELAY_ACCEPTANCE_MIN_WAIT_SEC" + envICEForceRelayConn = "NB_ICE_FORCE_RELAY_CONN" ) func iceKeepAlive() time.Duration { @@ -21,7 +22,7 @@ func iceKeepAlive() time.Duration { return iceKeepAliveDefault } - log.Debugf("setting ICE keep alive interval to %s seconds", keepAliveEnv) + log.Infof("setting ICE keep alive interval to %s seconds", keepAliveEnv) keepAliveEnvSec, err := strconv.Atoi(keepAliveEnv) if err != nil { log.Warnf("invalid value %s set for %s, using default %v", keepAliveEnv, envICEKeepAliveIntervalSec, iceKeepAliveDefault) @@ -37,7 +38,7 @@ func iceDisconnectedTimeout() time.Duration { return iceDisconnectedTimeoutDefault } - log.Debugf("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) + log.Infof("setting ICE disconnected timeout to %s seconds", disconnectedTimeoutEnv) disconnectedTimeoutSec, err := strconv.Atoi(disconnectedTimeoutEnv) if err != nil { log.Warnf("invalid value %s set for %s, using default %v", disconnectedTimeoutEnv, envICEDisconnectedTimeoutSec, iceDisconnectedTimeoutDefault) @@ -47,6 +48,22 @@ func iceDisconnectedTimeout() time.Duration { return time.Duration(disconnectedTimeoutSec) * time.Second } +func iceRelayAcceptanceMinWait() time.Duration { + iceRelayAcceptanceMinWaitEnv := os.Getenv(envICERelayAcceptanceMinWaitSec) + if iceRelayAcceptanceMinWaitEnv == "" { + return iceRelayAcceptanceMinWaitDefault + } + + log.Infof("setting ICE relay acceptance min wait to %s seconds", iceRelayAcceptanceMinWaitEnv) + disconnectedTimeoutSec, err := strconv.Atoi(iceRelayAcceptanceMinWaitEnv) + if err != nil { + log.Warnf("invalid value %s set for %s, using default %v", iceRelayAcceptanceMinWaitEnv, envICERelayAcceptanceMinWaitSec, iceRelayAcceptanceMinWaitDefault) + return iceRelayAcceptanceMinWaitDefault + } + + return time.Duration(disconnectedTimeoutSec) * time.Second +} + func hasICEForceRelayConn() bool { disconnectedTimeoutEnv := os.Getenv(envICEForceRelayConn) return strings.ToLower(disconnectedTimeoutEnv) == "true" diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index b2dff7f08..370ad5cf4 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "time" log "github.com/sirupsen/logrus" @@ -18,6 +19,7 @@ type routerPeerStatus struct { connected bool relayed bool direct bool + latency time.Duration } type routesUpdate struct { @@ -68,6 +70,7 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { connected: peerStatus.ConnStatus == peer.StatusConnected, relayed: peerStatus.Relayed, direct: peerStatus.Direct, + latency: peerStatus.Latency, } } return routePeerStatuses @@ -83,11 +86,13 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus { // * Non-relayed: Routes without relays are preferred. // * Direct connections: Routes with direct peer connections are favored. // * Stability: In case of equal scores, the currently active route (if any) is maintained. +// * Latency: Routes with lower latency are prioritized. // // It returns the ID of the selected optimal route. func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string { chosen := "" - chosenScore := 0 + chosenScore := float64(0) + currScore := float64(0) currID := "" if c.chosenRoute != nil { @@ -95,7 +100,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro } for _, r := range c.routes { - tempScore := 0 + tempScore := float64(0) peerStatus, found := routePeerStatuses[r.ID] if !found || !peerStatus.connected { continue @@ -103,9 +108,18 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro if r.Metric < route.MaxMetric { metricDiff := route.MaxMetric - r.Metric - tempScore = metricDiff * 10 + tempScore = float64(metricDiff) * 10 } + // in some temporal cases, latency can be 0, so we set it to 1s to not block but try to avoid this route + latency := time.Second + if peerStatus.latency != 0 { + latency = peerStatus.latency + } else { + log.Warnf("peer %s has 0 latency", r.Peer) + } + tempScore += 1 - latency.Seconds() + if !peerStatus.relayed { tempScore++ } @@ -114,7 +128,7 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro tempScore++ } - if tempScore > chosenScore || (tempScore == chosenScore && r.ID == currID) { + if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") { chosen = r.ID chosenScore = tempScore } @@ -123,18 +137,26 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]ro chosen = r.ID chosenScore = tempScore } + + if r.ID == currID { + currScore = tempScore + } } - if chosen == "" { + switch { + case chosen == "": var peers []string for _, r := range c.routes { peers = append(peers, r.Peer) } log.Warnf("the network %s has not been assigned a routing peer as no peers from the list %s are currently connected", c.network, peers) - - } else if chosen != currID { - log.Infof("new chosen route is %s with peer %s with score %d for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) + case chosen != currID: + if currScore != 0 && currScore < chosenScore+0.1 { + return currID + } else { + log.Infof("new chosen route is %s with peer %s with score %f for network %s", chosen, c.routes[chosen].Peer, chosenScore, c.network) + } } return chosen @@ -193,7 +215,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error { func (c *clientNetwork) removeRouteFromPeerAndSystem() error { if c.chosenRoute != nil { - if err := removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("remove route %s from system, err: %v", c.network, err) } @@ -234,7 +256,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error { } } else { // otherwise add the route to the system - if err := addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String(), c.wgInterface.Name()); err != nil { + if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil { return fmt.Errorf("route %s couldn't be added for peer %s, err: %v", c.network.String(), c.wgInterface.Address().IP.String(), err) } diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 3700d72ec..d24d42b8e 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -3,6 +3,7 @@ package routemanager import ( "net/netip" "testing" + "time" "github.com/netbirdio/netbird/route" ) @@ -13,7 +14,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { name string statuses map[string]routerPeerStatus expectedRouteID string - currentRoute *route.Route + currentRoute string existingRoutes map[string]*route.Route }{ { @@ -32,7 +33,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -51,7 +52,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -70,7 +71,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -89,7 +90,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer1", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "", }, { @@ -118,7 +119,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -147,7 +148,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, { @@ -176,18 +177,141 @@ func TestGetBestrouteFromStatuses(t *testing.T) { Peer: "peer2", }, }, - currentRoute: nil, + currentRoute: "", expectedRouteID: "route1", }, + { + name: "multiple connected peers with different latencies", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + latency: 300 * time.Millisecond, + }, + "route2": { + connected: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "should ignore routes with latency 0", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "", + expectedRouteID: "route2", + }, + { + name: "current route with similar score and similar but slightly worse latency should not change", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + direct: true, + latency: 12 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + direct: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "current chosen route doesn't exist anymore", + statuses: map[string]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + direct: true, + latency: 20 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + direct: true, + latency: 10 * time.Millisecond, + }, + }, + existingRoutes: map[string]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "routeDoesntExistAnymore", + expectedRouteID: "route2", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + currentRoute := &route.Route{ + ID: "routeDoesntExistAnymore", + } + if tc.currentRoute != "" { + currentRoute = tc.existingRoutes[tc.currentRoute] + } + // create new clientNetwork client := &clientNetwork{ network: netip.MustParsePrefix("192.168.0.0/24"), routes: tc.existingRoutes, - chosenRoute: tc.currentRoute, + chosenRoute: currentRoute, } chosenRoute := client.getBestRouteFromStatuses(tc.statuses) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6a0d954da..36a37f02c 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -3,7 +3,9 @@ package routemanager import ( "context" "fmt" + "net" "net/netip" + "net/url" "runtime" "sync" @@ -24,7 +26,7 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0) // Manager is a route manager interface type Manager interface { - Init() error + Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string @@ -65,16 +67,21 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface, } // Init sets up the routing -func (m *DefaultManager) Init() error { +func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { if err := cleanupRouting(); err != nil { log.Warnf("Failed cleaning up routing: %v", err) } - if err := setupRouting(); err != nil { - return fmt.Errorf("setup routing: %w", err) + mgmtAddress := m.statusRecorder.GetManagementState().URL + signalAddress := m.statusRecorder.GetSignalState().URL + ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + + beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface) + if err != nil { + return nil, nil, fmt.Errorf("setup routing: %w", err) } log.Info("Routing setup complete") - return nil + return beforePeerHook, afterPeerHook, nil } func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { @@ -203,16 +210,36 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou } func isPrefixSupported(prefix netip.Prefix) bool { - if runtime.GOOS == "linux" { + switch runtime.GOOS { + case "linux", "windows", "darwin": return true } // If prefix is too small, lets assume it is a possible default prefix which is not yet supported // we skip this prefix management - if prefix.Bits() < minRangeBits { + if prefix.Bits() <= minRangeBits { log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix", version.NetbirdVersion(), prefix) return false } return true } + +// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs. +func resolveURLsToIPs(urls []string) []net.IP { + var ips []net.IP + for _, rawurl := range urls { + u, err := url.Parse(rawurl) + if err != nil { + log.Errorf("Failed to parse url %s: %v", rawurl, err) + continue + } + ipAddrs, err := net.LookupIP(u.Hostname()) + if err != nil { + log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err) + continue + } + ips = append(ips, ipAddrs...) + } + return ips +} diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 9d92bf90d..03e77e09b 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -28,14 +28,14 @@ const remotePeerKey2 = "remote1" func TestManagerUpdateRoutes(t *testing.T) { testCases := []struct { - name string - inputInitRoutes []*route.Route - inputRoutes []*route.Route - inputSerial uint64 - removeSrvRouter bool - serverRoutesExpected int - clientNetworkWatchersExpected int - clientNetworkWatchersExpectedLinux int + name string + inputInitRoutes []*route.Route + inputRoutes []*route.Route + inputSerial uint64 + removeSrvRouter bool + serverRoutesExpected int + clientNetworkWatchersExpected int + clientNetworkWatchersExpectedAllowed int }{ { name: "Should create 2 client networks", @@ -201,9 +201,9 @@ func TestManagerUpdateRoutes(t *testing.T) { Enabled: true, }, }, - inputSerial: 1, - clientNetworkWatchersExpected: 0, - clientNetworkWatchersExpectedLinux: 1, + inputSerial: 1, + clientNetworkWatchersExpected: 0, + clientNetworkWatchersExpectedAllowed: 1, }, { name: "Remove 1 Client Route", @@ -417,7 +417,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil) - err = routeManager.Init() + + _, _, err = routeManager.Init() + require.NoError(t, err, "should init route manager") defer routeManager.Stop() @@ -434,8 +436,8 @@ func TestManagerUpdateRoutes(t *testing.T) { require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected - if runtime.GOOS == "linux" && testCase.clientNetworkWatchersExpectedLinux != 0 { - expectedWatchers = testCase.clientNetworkWatchersExpectedLinux + if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 { + expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed } require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index e812b3a85..dd2c28e59 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -6,6 +6,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/listener" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -16,8 +17,8 @@ type MockManager struct { StopFunc func() } -func (m *MockManager) Init() error { - return nil +func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil } // InitialRouteRange mock implementation of InitialRouteRange from Manager interface diff --git a/client/internal/routemanager/routemanager.go b/client/internal/routemanager/routemanager.go new file mode 100644 index 000000000..8f9ff9f4b --- /dev/null +++ b/client/internal/routemanager/routemanager.go @@ -0,0 +1,126 @@ +//go:build !android && !ios + +package routemanager + +import ( + "errors" + "fmt" + "net/netip" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type ref struct { + count int + nexthop netip.Addr + intf string +} + +type RouteManager struct { + // refCountMap keeps track of the reference ref for prefixes + refCountMap map[netip.Prefix]ref + // prefixMap keeps track of the prefixes associated with a connection ID for removal + prefixMap map[nbnet.ConnectionID][]netip.Prefix + addRoute AddRouteFunc + removeRoute RemoveRouteFunc + mutex sync.Mutex +} + +type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error) +type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error + +func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager { + // TODO: read initial routing table into refCountMap + return &RouteManager{ + refCountMap: map[netip.Prefix]ref{}, + prefixMap: map[nbnet.ConnectionID][]netip.Prefix{}, + addRoute: addRoute, + removeRoute: removeRoute, + } +} + +func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + ref := rm.refCountMap[prefix] + log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix) + + // Add route to the system, only if it's a new prefix + if ref.count == 0 { + log.Debugf("Adding route for prefix %s", prefix) + nexthop, intf, err := rm.addRoute(prefix) + if errors.Is(err, ErrRouteNotFound) { + return nil + } + if errors.Is(err, ErrRouteNotAllowed) { + log.Debugf("Adding route for prefix %s: %s", prefix, err) + } + if err != nil { + return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err) + } + ref.nexthop = nexthop + ref.intf = intf + } + + ref.count++ + rm.refCountMap[prefix] = ref + rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix) + + return nil +} + +func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + prefixes, ok := rm.prefixMap[connID] + if !ok { + log.Debugf("No prefixes found for connection ID %s", connID) + return nil + } + + var result *multierror.Error + for _, prefix := range prefixes { + ref := rm.refCountMap[prefix] + log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix) + if ref.count == 1 { + log.Debugf("Removing route for prefix %s", prefix) + // TODO: don't fail if the route is not found + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + continue + } + delete(rm.refCountMap, prefix) + } else { + ref.count-- + rm.refCountMap[prefix] = ref + } + } + delete(rm.prefixMap, connID) + + return result.ErrorOrNil() +} + +// Flush removes all references and routes from the system +func (rm *RouteManager) Flush() error { + rm.mutex.Lock() + defer rm.mutex.Unlock() + + var result *multierror.Error + for prefix := range rm.refCountMap { + log.Debugf("Removing route for prefix %s", prefix) + ref := rm.refCountMap[prefix] + if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil { + result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err)) + } + } + rm.refCountMap = map[netip.Prefix]ref{} + rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{} + + return result.ErrorOrNil() +} diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 00df735fb..af82dc913 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -155,11 +155,13 @@ func (m *defaultServerRouter) cleanUp() { log.Errorf("Failed to remove cleanup route: %v", err) } - state := m.statusRecorder.GetLocalPeerState() - state.Routes = nil - m.statusRecorder.UpdateLocalPeerState(state) } + + state := m.statusRecorder.GetLocalPeerState() + state.Routes = nil + m.statusRecorder.UpdateLocalPeerState(state) } + func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) { parsed, err := netip.ParsePrefix(source) if err != nil { diff --git a/client/internal/routemanager/systemops.go b/client/internal/routemanager/systemops.go new file mode 100644 index 000000000..1ee54b746 --- /dev/null +++ b/client/internal/routemanager/systemops.go @@ -0,0 +1,428 @@ +//go:build !android && !ios + +package routemanager + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "runtime" + "strconv" + + "github.com/hashicorp/go-multierror" + "github.com/libp2p/go-netroute" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" + nbnet "github.com/netbirdio/netbird/util/net" +) + +var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) +var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) +var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) +var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) + +var ErrRouteNotFound = errors.New("route not found") +var ErrRouteNotAllowed = errors.New("route not allowed") + +// TODO: fix: for default our wg address now appears as the default gw +func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { + addr := netip.IPv4Unspecified() + if prefix.Addr().Is6() { + addr = netip.IPv6Unspecified() + } + + defaultGateway, _, err := getNextHop(addr) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + return fmt.Errorf("get existing route gateway: %s", err) + } + + if !prefix.Contains(defaultGateway) { + log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix) + return nil + } + + gatewayPrefix := netip.PrefixFrom(defaultGateway, 32) + if defaultGateway.Is6() { + gatewayPrefix = netip.PrefixFrom(defaultGateway, 128) + } + + ok, err := existsInRouteTable(gatewayPrefix) + if err != nil { + return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) + } + + if ok { + log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) + return nil + } + + var exitIntf string + gatewayHop, intf, err := getNextHop(defaultGateway) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) + } + if intf != nil { + exitIntf = intf.Name + } + + log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) + return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf) +} + +func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) { + r, err := netroute.New() + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err) + } + intf, gateway, preferredSrc, err := r.Route(ip.AsSlice()) + if err != nil { + log.Warnf("Failed to get route for %s: %v", ip, err) + return netip.Addr{}, nil, ErrRouteNotFound + } + + log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc) + if gateway == nil { + if preferredSrc == nil { + return netip.Addr{}, nil, ErrRouteNotFound + } + log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc) + + addr, err := ipToAddr(preferredSrc, intf) + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("convert preferred source to address: %w", err) + } + return addr.Unmap(), intf, nil + } + + addr, err := ipToAddr(gateway, intf) + if err != nil { + return netip.Addr{}, nil, fmt.Errorf("convert gateway to address: %w", err) + } + + return addr, intf, nil +} + +// converts a net.IP to a netip.Addr including the zone based on the passed interface +func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return netip.Addr{}, fmt.Errorf("failed to convert IP address to netip.Addr: %s", ip) + } + + if intf != nil && (addr.IsLinkLocalMulticast() || addr.IsLinkLocalUnicast()) { + log.Tracef("Adding zone %s to address %s", intf.Name, addr) + if runtime.GOOS == "windows" { + addr = addr.WithZone(strconv.Itoa(intf.Index)) + } else { + addr = addr.WithZone(intf.Name) + } + } + + return addr.Unmap(), nil +} + +func existsInRouteTable(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute == prefix { + return true, nil + } + } + return false, nil +} + +func isSubRange(prefix netip.Prefix) (bool, error) { + routes, err := getRoutesFromTable() + if err != nil { + return false, fmt.Errorf("get routes from table: %w", err) + } + for _, tableRoute := range routes { + if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { + return true, nil + } + } + return false, nil +} + +// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. +// If the next hop or interface is pointing to the VPN interface, it will return the initial values. +func addRouteToNonVPNIntf( + prefix netip.Prefix, + vpnIntf *iface.WGIface, + initialNextHop netip.Addr, + initialIntf *net.Interface, +) (netip.Addr, string, error) { + addr := prefix.Addr() + switch { + case addr.IsLoopback(), + addr.IsLinkLocalUnicast(), + addr.IsLinkLocalMulticast(), + addr.IsInterfaceLocalMulticast(), + addr.IsUnspecified(), + addr.IsMulticast(): + + return netip.Addr{}, "", ErrRouteNotAllowed + } + + // Determine the exit interface and next hop for the prefix, so we can add a specific route + nexthop, intf, err := getNextHop(addr) + if err != nil { + return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err) + } + + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf) + exitNextHop := nexthop + var exitIntf string + if intf != nil { + exitIntf = intf.Name + } + + vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP) + if !ok { + return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr") + } + + // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values + if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() { + log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) + exitNextHop = initialNextHop + if initialIntf != nil { + exitIntf = initialIntf.Name + } + } + + log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop) + if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil { + return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err) + } + + return exitNextHop, exitIntf, nil +} + +// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix +// in two /1 prefixes to avoid replacing the existing default route +func genericAddVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + return err + } + if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return err + } + + // TODO: remove once IPv6 is supported on the interface + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } else if prefix == defaultv6 { + if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + return fmt.Errorf("add unreachable route split 1: %w", err) + } + if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil { + log.Warnf("Failed to rollback route addition: %s", err2) + } + return fmt.Errorf("add unreachable route split 2: %w", err) + } + + return nil + } + + return addNonExistingRoute(prefix, intf) +} + +// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table +func addNonExistingRoute(prefix netip.Prefix, intf string) error { + ok, err := existsInRouteTable(prefix) + if err != nil { + return fmt.Errorf("exists in route table: %w", err) + } + if ok { + log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) + return nil + } + + ok, err = isSubRange(prefix) + if err != nil { + return fmt.Errorf("sub range: %w", err) + } + + if ok { + err := addRouteForCurrentDefaultGateway(prefix) + if err != nil { + log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) + } + } + + return addToRouteTable(prefix, netip.Addr{}, intf) +} + +// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given, +// it will remove the split /1 prefixes +func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error { + if prefix == defaultv4 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + // TODO: remove once IPv6 is supported on the interface + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } else if prefix == defaultv6 { + var result *multierror.Error + if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil { + result = multierror.Append(result, err) + } + + return result.ErrorOrNil() + } + + return removeFromRouteTable(prefix, netip.Addr{}, intf) +} + +func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + return nil, fmt.Errorf("parse IP address: %s", ip) + } + addr = addr.Unmap() + + var prefixLength int + switch { + case addr.Is4(): + prefixLength = 32 + case addr.Is6(): + prefixLength = 128 + default: + return nil, fmt.Errorf("invalid IP address: %s", addr) + } + + prefix := netip.PrefixFrom(addr, prefixLength) + return &prefix, nil +} + +func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified()) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + log.Errorf("Unable to get initial v4 default next hop: %v", err) + } + initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified()) + if err != nil && !errors.Is(err, ErrRouteNotFound) { + log.Errorf("Unable to get initial v6 default next hop: %v", err) + } + + *routeManager = NewRouteManager( + func(prefix netip.Prefix) (netip.Addr, string, error) { + addr := prefix.Addr() + nexthop, intf := initialNextHopV4, initialIntfV4 + if addr.Is6() { + nexthop, intf = initialNextHopV6, initialIntfV6 + } + return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf) + }, + removeFromRouteTable, + ) + + return setupHooks(*routeManager, initAddresses) +} + +func cleanupRoutingWithRouteManager(routeManager *RouteManager) error { + if routeManager == nil { + return nil + } + + // TODO: Remove hooks selectively + nbnet.RemoveDialerHooks() + nbnet.RemoveListenerHooks() + + if err := routeManager.Flush(); err != nil { + return fmt.Errorf("flush route manager: %w", err) + } + + return nil +} + +func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { + prefix, err := getPrefixFromIP(ip) + if err != nil { + return fmt.Errorf("convert ip to prefix: %w", err) + } + + if err := routeManager.AddRouteRef(connID, *prefix); err != nil { + return fmt.Errorf("adding route reference: %v", err) + } + + return nil + } + afterHook := func(connID nbnet.ConnectionID) error { + if err := routeManager.RemoveRouteRef(connID); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + return nil + } + + for _, ip := range initAddresses { + if err := beforeHook("init", ip); err != nil { + log.Errorf("Failed to add route reference: %v", err) + } + } + + nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error { + if ctx.Err() != nil { + return ctx.Err() + } + + var result *multierror.Error + for _, ip := range resolvedIPs { + result = multierror.Append(result, beforeHook(connID, ip.IP)) + } + return result.ErrorOrNil() + }) + + nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { + return afterHook(connID) + }) + + nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error { + return beforeHook(connID, ip.IP) + }) + + nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error { + return afterHook(connID) + }) + + return beforeHook, afterHook, nil +} diff --git a/client/internal/routemanager/systemops_android.go b/client/internal/routemanager/systemops_android.go index 291826780..34d2d270f 100644 --- a/client/internal/routemanager/systemops_android.go +++ b/client/internal/routemanager/systemops_android.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { + return nil +} + +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_bsd.go b/client/internal/routemanager/systemops_bsd.go index 173e7c0e8..b6a2006e7 100644 --- a/client/internal/routemanager/systemops_bsd.go +++ b/client/internal/routemanager/systemops_bsd.go @@ -8,6 +8,7 @@ import ( "net/netip" "syscall" + log "github.com/sirupsen/logrus" "golang.org/x/net/route" ) @@ -51,16 +52,24 @@ func getRoutesFromTable() ([]netip.Prefix, error) { continue } + if len(m.Addrs) < 3 { + log.Warnf("Unexpected RIB message Addrs: %v", m.Addrs) + continue + } + addr, ok := toNetIPAddr(m.Addrs[0]) if !ok { continue } - mask, ok := toNetIPMASK(m.Addrs[2]) - if !ok { - continue + cidr := 32 + if mask := m.Addrs[2]; mask != nil { + cidr, ok = toCIDR(mask) + if !ok { + log.Debugf("Unexpected RIB message Addrs[2]: %v", mask) + continue + } } - cidr, _ := mask.Size() routePrefix := netip.PrefixFrom(addr, cidr) if routePrefix.IsValid() { @@ -73,20 +82,19 @@ func getRoutesFromTable() ([]netip.Prefix, error) { func toNetIPAddr(a route.Addr) (netip.Addr, bool) { switch t := a.(type) { case *route.Inet4Addr: - ip := net.IPv4(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - addr := netip.MustParseAddr(ip.String()) - return addr, true + return netip.AddrFrom4(t.IP), true default: return netip.Addr{}, false } } -func toNetIPMASK(a route.Addr) (net.IPMask, bool) { +func toCIDR(a route.Addr) (int, bool) { switch t := a.(type) { case *route.Inet4Addr: mask := net.IPv4Mask(t.IP[0], t.IP[1], t.IP[2], t.IP[3]) - return mask, true + cidr, _ := mask.Size() + return cidr, true default: - return nil, false + return 0, false } } diff --git a/client/internal/routemanager/systemops_bsd_nonios.go b/client/internal/routemanager/systemops_bsd_nonios.go deleted file mode 100644 index f60c7afc3..000000000 --- a/client/internal/routemanager/systemops_bsd_nonios.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && !ios - -package routemanager - -import "net/netip" - -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) -} - -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) -} diff --git a/client/internal/routemanager/systemops_darwin.go b/client/internal/routemanager/systemops_darwin.go new file mode 100644 index 000000000..f7ce72a4e --- /dev/null +++ b/client/internal/routemanager/systemops_darwin.go @@ -0,0 +1,89 @@ +//go:build darwin && !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "strings" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" +) + +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("add", prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return routeCmd("delete", prefix, nexthop, intf) +} + +func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error { + inet := "-inet" + network := prefix.String() + if prefix.IsSingleIP() { + network = prefix.Addr().String() + } + if prefix.Addr().Is6() { + inet = "-inet6" + // Special case for IPv6 split default route, pointing to the wg interface fails + // TODO: Remove once we have IPv6 support on the interface + if prefix.Bits() == 1 { + intf = "lo0" + } + } + + args := []string{"-n", action, inet, network} + if nexthop.IsValid() { + args = append(args, nexthop.Unmap().String()) + } else if intf != "" { + args = append(args, "-interface", intf) + } + + if err := retryRouteCmd(args); err != nil { + return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err) + } + return nil +} + +func retryRouteCmd(args []string) error { + operation := func() error { + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + // https://github.com/golang/go/issues/45736 + if err != nil && strings.Contains(string(out), "sysctl: cannot allocate memory") { + return err + } else if err != nil { + return backoff.Permanent(err) + } + return nil + } + + expBackOff := backoff.NewExponentialBackOff() + expBackOff.InitialInterval = 50 * time.Millisecond + expBackOff.MaxInterval = 500 * time.Millisecond + expBackOff.MaxElapsedTime = 1 * time.Second + + err := backoff.Retry(operation, expBackOff) + if err != nil { + return fmt.Errorf("route cmd retry failed: %w", err) + } + return nil +} diff --git a/client/internal/routemanager/systemops_darwin_test.go b/client/internal/routemanager/systemops_darwin_test.go new file mode 100644 index 000000000..cc9bb9db5 --- /dev/null +++ b/client/internal/routemanager/systemops_darwin_test.go @@ -0,0 +1,138 @@ +//go:build !ios + +package routemanager + +import ( + "fmt" + "net" + "net/netip" + "os/exec" + "regexp" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var expectedVPNint = "utun100" +var expectedExternalInt = "lo0" +var expectedInternalInt = "lo0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via vpn", + destination: "10.10.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53), + }, + }...) +} + +func TestConcurrentRoutes(t *testing.T) { + baseIP := netip.MustParseAddr("192.0.2.0") + intf := "lo0" + + var wg sync.WaitGroup + for i := 0; i < 1024; i++ { + wg.Add(1) + go func(ip netip.Addr) { + defer wg.Done() + prefix := netip.PrefixFrom(ip, 32) + if err := addToRouteTable(prefix, netip.Addr{}, intf); err != nil { + t.Errorf("Failed to add route for %s: %v", prefix, err) + } + }(baseIP) + baseIP = baseIP.Next() + } + + wg.Wait() + + baseIP = netip.MustParseAddr("192.0.2.0") + + for i := 0; i < 1024; i++ { + wg.Add(1) + go func(ip netip.Addr) { + defer wg.Done() + prefix := netip.PrefixFrom(ip, 32) + if err := removeFromRouteTable(prefix, netip.Addr{}, intf); err != nil { + t.Errorf("Failed to remove route for %s: %v", prefix, err) + } + }(baseIP) + baseIP = baseIP.Next() + } + + wg.Wait() +} + +func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string { + t.Helper() + + err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run() + require.NoError(t, err, "Failed to create loopback alias") + + t.Cleanup(func() { + err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run() + assert.NoError(t, err, "Failed to remove loopback alias") + }) + + return "lo0" +} + +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) { + t.Helper() + + var originalNexthop net.IP + if dstCIDR == "0.0.0.0/0" { + var err error + originalNexthop, err = fetchOriginalGateway() + if err != nil { + t.Logf("Failed to fetch original gateway: %v", err) + } + + if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil { + t.Logf("Failed to delete route: %v, output: %s", err, output) + } + } + + t.Cleanup(func() { + if originalNexthop != nil { + err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run() + assert.NoError(t, err, "Failed to restore original route") + } + }) + + err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run() + require.NoError(t, err, "Failed to add route") + + t.Cleanup(func() { + err := exec.Command("route", "delete", "-net", dstCIDR).Run() + assert.NoError(t, err, "Failed to remove route") + }) +} + +func fetchOriginalGateway() (net.IP, error) { + output, err := exec.Command("route", "-n", "get", "default").CombinedOutput() + if err != nil { + return nil, err + } + + matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output)) + if len(matches) == 0 { + return nil, fmt.Errorf("gateway not found") + } + + return net.ParseIP(matches[1]), nil +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24") + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) + + otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24") + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) +} diff --git a/client/internal/routemanager/systemops_ios.go b/client/internal/routemanager/systemops_ios.go index 291826780..34d2d270f 100644 --- a/client/internal/routemanager/systemops_ios.go +++ b/client/internal/routemanager/systemops_ios.go @@ -1,13 +1,33 @@ package routemanager import ( + "net" "net/netip" + "runtime" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) -func addToRouteTableIfNoExists(prefix netip.Prefix, addr, intf string) error { +func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return nil, nil, nil +} + +func cleanupRouting() error { return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr, intf string) error { +func enableIPForwarding() error { + log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) + return nil +} + +func addVPNRoute(netip.Prefix, string) error { + return nil +} + +func removeVPNRoute(netip.Prefix, string) error { return nil } diff --git a/client/internal/routemanager/systemops_linux.go b/client/internal/routemanager/systemops_linux.go index 192509992..dd00626e1 100644 --- a/client/internal/routemanager/systemops_linux.go +++ b/client/internal/routemanager/systemops_linux.go @@ -4,17 +4,21 @@ package routemanager import ( "bufio" + "context" "errors" "fmt" "net" "net/netip" "os" "syscall" + "time" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -33,6 +37,9 @@ const ( var ErrTableIDExists = errors.New("ID exists with different name") +var routeManager = &RouteManager{} +var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true" + type ruleParams struct { fwmark int tableID int @@ -45,10 +52,10 @@ type ruleParams struct { func getSetupRules() []ruleParams { return []ruleParams{ - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "add rule v4 netbird"}, - {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "add rule v6 netbird"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "add rule with suppress prefixlen v4"}, - {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "add rule with suppress prefixlen v6"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"}, + {nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"}, + {-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"}, } } @@ -64,7 +71,12 @@ func getSetupRules() []ruleParams { // enabling VPN connectivity. // // The rules are inserted in reverse order, as rules are added from the bottom up in the rule list. -func setupRouting() (err error) { +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) { + if isLegacy { + log.Infof("Using legacy routing setup") + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + if err = addRoutingTableName(); err != nil { log.Errorf("Error adding routing table name: %v", err) } @@ -80,17 +92,26 @@ func setupRouting() (err error) { rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - return fmt.Errorf("%s: %w", rule.description, err) + if errors.Is(err, syscall.EOPNOTSUPP) { + log.Warnf("Rule operations are not supported, falling back to the legacy routing setup") + isLegacy = true + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) + } + return nil, nil, fmt.Errorf("%s: %w", rule.description, err) } } - return nil + return nil, nil, nil } // cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // It systematically removes the three rules and any associated routing table entries to ensure a clean state. // The function uses error aggregation to report any errors encountered during the cleanup process. func cleanupRouting() error { + if isLegacy { + return cleanupRoutingWithRouteManager(routeManager) + } + var result *multierror.Error if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { @@ -102,7 +123,7 @@ func cleanupRouting() error { rules := getSetupRules() for _, rule := range rules { - if err := removeAllRules(rule); err != nil { + if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) { result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err)) } } @@ -110,157 +131,61 @@ func cleanupRouting() error { return result.ErrorOrNil() } -func addToRouteTableIfNoExists(prefix netip.Prefix, _ string, intf string) error { - // No need to check if routes exist as main table takes precedence over the VPN table via Rule 2 +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN) +} + +func addVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericAddVPNRoute(prefix, intf) + } + + // No need to check if routes exist as main table takes precedence over the VPN table via Rule 1 // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := addUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("add blackhole: %w", err) } } - if err := addRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("add route: %w", err) } return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, _ string, intf string) error { +func removeVPNRoute(prefix netip.Prefix, intf string) error { + if isLegacy { + return genericRemoveVPNRoute(prefix, intf) + } + // TODO remove this once we have ipv6 support if prefix == defaultv4 { - if err := removeUnreachableRoute(&defaultv6, NetbirdVPNTableID, netlink.FAMILY_V6); err != nil { + if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove unreachable route: %w", err) } } - if err := removeRoute(&prefix, nil, &intf, NetbirdVPNTableID, netlink.FAMILY_V4); err != nil { + if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil { return fmt.Errorf("remove route: %w", err) } return nil } func getRoutesFromTable() ([]netip.Prefix, error) { - return getRoutes(NetbirdVPNTableID, netlink.FAMILY_V4) -} - -// addRoute adds a route to a specific routing table identified by tableID. -func addRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: family, - } - - if prefix != nil { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - route.Dst = ipNet - } - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { - return fmt.Errorf("netlink add route: %w", err) - } - - return nil -} - -// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. -// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. -// tableID specifies the routing table to which the unreachable route will be added. -func addUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) + v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return nil, fmt.Errorf("get v4 routes: %w", err) } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: ipFamily, - Dst: ipNet, - } - - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) { - return fmt.Errorf("netlink add unreachable route: %w", err) - } - - return nil -} - -func removeUnreachableRoute(prefix *netip.Prefix, tableID, ipFamily int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) + v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6) if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) + return nil, fmt.Errorf("get v6 routes: %w", err) + } - - route := &netlink.Route{ - Type: syscall.RTN_UNREACHABLE, - Table: tableID, - Family: ipFamily, - Dst: ipNet, - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { - return fmt.Errorf("netlink remove unreachable route: %w", err) - } - - return nil - -} - -// removeRoute removes a route from a specific routing table identified by tableID. -func removeRoute(prefix *netip.Prefix, addr, intf *string, tableID, family int) error { - _, ipNet, err := net.ParseCIDR(prefix.String()) - if err != nil { - return fmt.Errorf("parse prefix %s: %w", prefix, err) - } - - route := &netlink.Route{ - Scope: netlink.SCOPE_UNIVERSE, - Table: tableID, - Family: family, - Dst: ipNet, - } - - if err := addNextHop(addr, intf, route); err != nil { - return fmt.Errorf("add gateway and device: %w", err) - } - - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) { - return fmt.Errorf("netlink remove route: %w", err) - } - - return nil -} - -func flushRoutes(tableID, family int) error { - routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) - if err != nil { - return fmt.Errorf("list routes from table %d: %w", tableID, err) - } - - var result *multierror.Error - for i := range routes { - route := routes[i] - // unreachable default routes don't come back with Dst set - if route.Gw == nil && route.Src == nil && route.Dst == nil { - if family == netlink.FAMILY_V4 { - routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} - } else { - routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} - } - } - if err := netlink.RouteDel(&routes[i]); err != nil { - result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) - } - } - - return result.ErrorOrNil() + return append(v4Routes, v6Routes...), nil } // getRoutes fetches routes from a specific routing table identified by tableID. @@ -291,6 +216,125 @@ func getRoutes(tableID, family int) ([]netip.Prefix, error) { return prefixList, nil } +// addRoute adds a route to a specific routing table identified by tableID. +func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: getAddressFamily(prefix), + } + + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + route.Dst = ipNet + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink add route: %w", err) + } + + return nil +} + +// addUnreachableRoute adds an unreachable route for the specified IP family and routing table. +// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6. +// tableID specifies the routing table to which the unreachable route will be added. +func addUnreachableRoute(prefix netip.Prefix, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink add unreachable route: %w", err) + } + + return nil +} + +func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Type: syscall.RTN_UNREACHABLE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink remove unreachable route: %w", err) + } + + return nil + +} + +// removeRoute removes a route from a specific routing table identified by tableID. +func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error { + _, ipNet, err := net.ParseCIDR(prefix.String()) + if err != nil { + return fmt.Errorf("parse prefix %s: %w", prefix, err) + } + + route := &netlink.Route{ + Scope: netlink.SCOPE_UNIVERSE, + Table: tableID, + Family: getAddressFamily(prefix), + Dst: ipNet, + } + + if err := addNextHop(addr, intf, route); err != nil { + return fmt.Errorf("add gateway and device: %w", err) + } + + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + return fmt.Errorf("netlink remove route: %w", err) + } + + return nil +} + +func flushRoutes(tableID, family int) error { + routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE) + if err != nil { + return fmt.Errorf("list routes from table %d: %w", tableID, err) + } + + var result *multierror.Error + for i := range routes { + route := routes[i] + // unreachable default routes don't come back with Dst set + if route.Gw == nil && route.Src == nil && route.Dst == nil { + if family == netlink.FAMILY_V4 { + routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} + } else { + routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} + } + } + if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) + } + } + + return result.ErrorOrNil() +} + func enableIPForwarding() error { bytes, err := os.ReadFile(ipv4ForwardingPath) if err != nil { @@ -385,7 +429,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("add routing rule: %w", err) } @@ -410,35 +454,58 @@ func removeRule(params ruleParams) error { } func removeAllRules(params ruleParams) error { - for { - if err := removeRule(params); err != nil { - if errors.Is(err, syscall.ENOENT) { - break + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + done := make(chan error, 1) + go func() { + for { + if ctx.Err() != nil { + done <- ctx.Err() + return + } + if err := removeRule(params); err != nil { + if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) { + done <- nil + return + } + done <- err + return } - return err } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-done: + return err } - return nil } // addNextHop adds the gateway and device to the route. -func addNextHop(addr *string, intf *string, route *netlink.Route) error { - if addr != nil { - ip := net.ParseIP(*addr) - if ip == nil { - return fmt.Errorf("parsing address %s failed", *addr) +func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error { + if addr.IsValid() { + route.Gw = addr.AsSlice() + if intf == "" { + intf = addr.Zone() } - - route.Gw = ip } - if intf != nil { - link, err := netlink.LinkByName(*intf) + if intf != "" { + link, err := netlink.LinkByName(intf) if err != nil { - return fmt.Errorf("set interface %s: %w", *intf, err) + return fmt.Errorf("set interface %s: %w", intf, err) } route.LinkIndex = link.Attrs().Index } return nil } + +func getAddressFamily(prefix netip.Prefix) int { + if prefix.Addr().Is4() { + return netlink.FAMILY_V4 + } + return netlink.FAMILY_V6 +} diff --git a/client/internal/routemanager/systemops_linux_test.go b/client/internal/routemanager/systemops_linux_test.go index 96e43d20f..0043c3f4e 100644 --- a/client/internal/routemanager/systemops_linux_test.go +++ b/client/internal/routemanager/systemops_linux_test.go @@ -6,34 +6,38 @@ import ( "errors" "fmt" "net" - "net/netip" "os" "strings" "syscall" "testing" - "time" - "github.com/gopacket/gopacket" - "github.com/gopacket/gopacket/layers" - "github.com/gopacket/gopacket/pcap" - "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" - nbnet "github.com/netbirdio/netbird/util/net" ) -type PacketExpectation struct { - SrcIP net.IP - DstIP net.IP - SrcPort int - DstPort int - UDP bool - TCP bool +var expectedVPNint = "wgtest0" +var expectedLoopbackInt = "lo" +var expectedExternalInt = "dummyext0" +var expectedInternalInt = "dummyint0" + +func init() { + testCases = append(testCases, []testCase{ + { + name: "To more specific route without custom dialer via physical interface", + destination: "10.10.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53), + }, + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.1:53", + expectedInterface: expectedLoopbackInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), + }, + }...) } func TestEntryExists(t *testing.T) { @@ -92,157 +96,7 @@ func TestEntryExists(t *testing.T) { } } -func TestRoutingWithTables(t *testing.T) { - testCases := []struct { - name string - destination string - captureInterface string - dialer *net.Dialer - packetExpectation PacketExpectation - }{ - { - name: "To external host without fwmark via vpn", - destination: "192.0.2.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), - }, - { - name: "To external host with fwmark via physical interface", - destination: "192.0.2.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), - }, - - { - name: "To duplicate internal route with fwmark via physical interface", - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - { - name: "To duplicate internal route without fwmark via physical interface", // local route takes precedence - destination: "10.0.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.0.0.1", 53), - }, - - { - name: "To unique vpn route with fwmark via physical interface", - destination: "172.16.0.1:53", - captureInterface: "dummyext0", - dialer: nbnet.NewDialer(), - packetExpectation: createPacketExpectation("192.168.0.1", 12345, "172.16.0.1", 53), - }, - { - name: "To unique vpn route without fwmark via vpn", - destination: "172.16.0.1:53", - captureInterface: "wgtest0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("100.64.0.1", 12345, "172.16.0.1", 53), - }, - - { - name: "To more specific route without fwmark via vpn interface", - destination: "10.10.0.1:53", - captureInterface: "dummyint0", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("192.168.1.1", 12345, "10.10.0.1", 53), - }, - - { - name: "To more specific route (local) without fwmark via physical interface", - destination: "127.0.10.1:53", - captureInterface: "lo", - dialer: &net.Dialer{}, - packetExpectation: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - wgIface, _, _ := setupTestEnv(t) - - // default route exists in main table and vpn table - err := addToRouteTableIfNoExists(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.0.0.0/8 route exists in main table and vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 10.10.0.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // 127.0.10.0/24 more specific route exists in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - // unique route in vpn table - err = addToRouteTableIfNoExists(netip.MustParsePrefix("172.16.0.0/16"), wgIface.Address().IP.String(), wgIface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") - - filter := createBPFFilter(tc.destination) - handle := startPacketCapture(t, tc.captureInterface, filter) - - sendTestPacket(t, tc.destination, tc.packetExpectation.SrcPort, tc.dialer) - - packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) - packet, err := packetSource.NextPacket() - require.NoError(t, err) - - verifyPacket(t, packet, tc.packetExpectation) - }) - } -} - -func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { - t.Helper() - - ipLayer := packet.Layer(layers.LayerTypeIPv4) - require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") - - ip, ok := ipLayer.(*layers.IPv4) - require.True(t, ok, "Failed to cast to IPv4 layer") - - // Convert both source and destination IP addresses to 16-byte representation - expectedSrcIP := exp.SrcIP.To16() - actualSrcIP := ip.SrcIP.To16() - assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") - - expectedDstIP := exp.DstIP.To16() - actualDstIP := ip.DstIP.To16() - assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") - - if exp.UDP { - udpLayer := packet.Layer(layers.LayerTypeUDP) - require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") - - udp, ok := udpLayer.(*layers.UDP) - require.True(t, ok, "Failed to cast to UDP layer") - - assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") - assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") - } - - if exp.TCP { - tcpLayer := packet.Layer(layers.LayerTypeTCP) - require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") - - tcp, ok := tcpLayer.(*layers.TCP) - require.True(t, ok, "Failed to cast to TCP layer") - - assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") - assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") - } - -} - -func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) *netlink.Dummy { +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { t.Helper() dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}} @@ -264,35 +118,52 @@ func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR str require.NoError(t, err) } - return dummy + t.Cleanup(func() { + err := netlink.LinkDel(dummy) + assert.NoError(t, err) + }) + + return dummy.Name } -func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { +func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) { t.Helper() _, dstIPNet, err := net.ParseCIDR(dstCIDR) require.NoError(t, err) + // Handle existing routes with metric 0 + var originalNexthop net.IP + var originalLinkIndex int if dstIPNet.String() == "0.0.0.0/0" { - gw, linkIndex, err := fetchOriginalGateway(netlink.FAMILY_V4) - if err != nil { + var err error + originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4) + if err != nil && !errors.Is(err, ErrRouteNotFound) { t.Logf("Failed to fetch original gateway: %v", err) } - // Handle existing routes with metric 0 - err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) - if err == nil { - t.Cleanup(func() { - err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: gw, LinkIndex: linkIndex, Priority: 0}) - if err != nil && !errors.Is(err, syscall.EEXIST) { - t.Fatalf("Failed to add route: %v", err) - } - }) - } else if !errors.Is(err, syscall.ESRCH) { - t.Logf("Failed to delete route: %v", err) + if originalNexthop != nil { + err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0}) + switch { + case err != nil && !errors.Is(err, syscall.ESRCH): + t.Logf("Failed to delete route: %v", err) + case err == nil: + t.Cleanup(func() { + err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0}) + if err != nil && !errors.Is(err, syscall.EEXIST) { + t.Fatalf("Failed to add route: %v", err) + } + }) + default: + t.Logf("Failed to delete route: %v", err) + } } } + link, err := netlink.LinkByName(intf) + require.NoError(t, err) + linkIndex := link.Attrs().Index + route := &netlink.Route{ Dst: dstIPNet, Gw: gw, @@ -307,9 +178,9 @@ func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, linkIndex int) { if err != nil && !errors.Is(err, syscall.EEXIST) { t.Fatalf("Failed to add route: %v", err) } + require.NoError(t, err) } -// fetchOriginalGateway returns the original gateway IP address and the interface index. func fetchOriginalGateway(family int) (net.IP, int, error) { routes, err := netlink.RouteList(nil, family) if err != nil { @@ -317,153 +188,20 @@ func fetchOriginalGateway(family int) (net.IP, int, error) { } for _, route := range routes { - if route.Dst == nil { + if route.Dst == nil && route.Priority == 0 { return route.Gw, route.LinkIndex, nil } } - return nil, 0, fmt.Errorf("default route not found") + return nil, 0, ErrRouteNotFound } -func setupDummyInterfacesAndRoutes(t *testing.T) (string, string) { +func setupDummyInterfacesAndRoutes(t *testing.T) { t.Helper() defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24") - addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy.Attrs().Index) + addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy) otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24") - addDummyRoute(t, "10.0.0.0/8", nil, otherDummy.Attrs().Index) - - t.Cleanup(func() { - err := netlink.LinkDel(defaultDummy) - assert.NoError(t, err) - err = netlink.LinkDel(otherDummy) - assert.NoError(t, err) - }) - - return defaultDummy.Name, otherDummy.Name -} - -func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { - t.Helper() - - peerPrivateKey, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - newNet, err := stdnet.NewNet(nil) - require.NoError(t, err) - - wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) - require.NoError(t, err, "should create testing WireGuard interface") - - err = wgInterface.Create() - require.NoError(t, err, "should create testing WireGuard interface") - - t.Cleanup(func() { - wgInterface.Close() - }) - - return wgInterface -} - -func setupTestEnv(t *testing.T) (*iface.WGIface, string, string) { - t.Helper() - - defaultDummy, otherDummy := setupDummyInterfacesAndRoutes(t) - - wgIface := createWGInterface(t, "wgtest0", "100.64.0.1/24", 51820) - t.Cleanup(func() { - assert.NoError(t, wgIface.Close()) - }) - - err := setupRouting() - require.NoError(t, err, "setupRouting should not return err") - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - return wgIface, defaultDummy, otherDummy -} - -func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { - t.Helper() - - inactive, err := pcap.NewInactiveHandle(intf) - require.NoError(t, err, "Failed to create inactive pcap handle") - defer inactive.CleanUp() - - err = inactive.SetSnapLen(1600) - require.NoError(t, err, "Failed to set snap length on inactive handle") - - err = inactive.SetTimeout(time.Second * 10) - require.NoError(t, err, "Failed to set timeout on inactive handle") - - err = inactive.SetImmediateMode(true) - require.NoError(t, err, "Failed to set immediate mode on inactive handle") - - handle, err := inactive.Activate() - require.NoError(t, err, "Failed to activate pcap handle") - t.Cleanup(handle.Close) - - err = handle.SetBPFFilter(filter) - require.NoError(t, err, "Failed to set BPF filter") - - return handle -} - -func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer *net.Dialer) { - t.Helper() - - if dialer == nil { - dialer = &net.Dialer{} - } - - if sourcePort != 0 { - localUDPAddr := &net.UDPAddr{ - IP: net.IPv4zero, - Port: sourcePort, - } - dialer.LocalAddr = localUDPAddr - } - - msg := new(dns.Msg) - msg.Id = dns.Id() - msg.RecursionDesired = true - msg.Question = []dns.Question{ - {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - conn, err := dialer.Dial("udp", destination) - require.NoError(t, err, "Failed to dial UDP") - defer conn.Close() - - data, err := msg.Pack() - require.NoError(t, err, "Failed to pack DNS message") - - _, err = conn.Write(data) - if err != nil { - if strings.Contains(err.Error(), "required key not available") { - t.Logf("Ignoring WireGuard key error: %v", err) - return - } - t.Fatalf("Failed to send DNS query: %v", err) - } -} - -func createBPFFilter(destination string) string { - host, port, err := net.SplitHostPort(destination) - if err != nil { - return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) - } - return "udp" -} - -func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { - return PacketExpectation{ - SrcIP: net.ParseIP(srcIP), - DstIP: net.ParseIP(dstIP), - SrcPort: srcPort, - DstPort: dstPort, - UDP: true, - } + addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy) } diff --git a/client/internal/routemanager/systemops_nonandroid.go b/client/internal/routemanager/systemops_nonandroid.go deleted file mode 100644 index 65f670ace..000000000 --- a/client/internal/routemanager/systemops_nonandroid.go +++ /dev/null @@ -1,148 +0,0 @@ -//go:build !android - -//nolint:unused -package routemanager - -import ( - "errors" - "fmt" - "net" - "net/netip" - "os/exec" - "runtime" - - "github.com/libp2p/go-netroute" - log "github.com/sirupsen/logrus" -) - -var errRouteNotFound = fmt.Errorf("route not found") - -func genericAddRouteForCurrentDefaultGateway(prefix netip.Prefix) error { - defaultGateway, err := getExistingRIBRouteGateway(defaultv4) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("get existing route gateway: %s", err) - } - - addr := netip.MustParseAddr(defaultGateway.String()) - - if !prefix.Contains(addr) { - log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix) - return nil - } - - gatewayPrefix := netip.PrefixFrom(addr, 32) - - ok, err := existsInRouteTable(gatewayPrefix) - if err != nil { - return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err) - } - - if ok { - log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix) - return nil - } - - gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix) - if err != nil && !errors.Is(err, errRouteNotFound) { - return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err) - } - log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop) - return genericAddToRouteTable(gatewayPrefix, gatewayHop.String(), "") -} - -func genericAddToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - ok, err := existsInRouteTable(prefix) - if err != nil { - return fmt.Errorf("exists in route table: %w", err) - } - if ok { - log.Warnf("Skipping adding a new route for network %s because it already exists", prefix) - return nil - } - - ok, err = isSubRange(prefix) - if err != nil { - return fmt.Errorf("sub range: %w", err) - } - - if ok { - err := genericAddRouteForCurrentDefaultGateway(prefix) - if err != nil { - log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err) - } - } - - return genericAddToRouteTable(prefix, addr, intf) -} - -func genericRemoveFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTable(prefix, addr, intf) -} - -func genericAddToRouteTable(prefix netip.Prefix, addr, _ string) error { - cmd := exec.Command("route", "add", prefix.String(), addr) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("add route: %w", err) - } - log.Debugf(string(out)) - return nil -} - -func genericRemoveFromRouteTable(prefix netip.Prefix, addr, _ string) error { - args := []string{"delete", prefix.String()} - if runtime.GOOS == "darwin" { - args = append(args, addr) - } - cmd := exec.Command("route", args...) - out, err := cmd.Output() - if err != nil { - return fmt.Errorf("remove route: %w", err) - } - log.Debugf(string(out)) - return nil -} - -func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) { - r, err := netroute.New() - if err != nil { - return nil, fmt.Errorf("new netroute: %w", err) - } - _, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice()) - if err != nil { - log.Errorf("Getting routes returned an error: %v", err) - return nil, errRouteNotFound - } - - if gateway == nil { - return preferredSrc, nil - } - - return gateway, nil -} - -func existsInRouteTable(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if tableRoute == prefix { - return true, nil - } - } - return false, nil -} - -func isSubRange(prefix netip.Prefix) (bool, error) { - routes, err := getRoutesFromTable() - if err != nil { - return false, fmt.Errorf("get routes from table: %w", err) - } - for _, tableRoute := range routes { - if isPrefixSupported(tableRoute) && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() { - return true, nil - } - } - return false, nil -} diff --git a/client/internal/routemanager/systemops_nonlinux.go b/client/internal/routemanager/systemops_nonlinux.go index d793f0fbd..38026107e 100644 --- a/client/internal/routemanager/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops_nonlinux.go @@ -1,22 +1,23 @@ -//go:build !linux || android +//go:build !linux && !ios package routemanager import ( + "net/netip" "runtime" log "github.com/sirupsen/logrus" ) -func setupRouting() error { - return nil -} - -func cleanupRouting() error { - return nil -} - func enableIPForwarding() error { log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS) return nil } + +func addVPNRoute(prefix netip.Prefix, intf string) error { + return genericAddVPNRoute(prefix, intf) +} + +func removeVPNRoute(prefix netip.Prefix, intf string) error { + return genericRemoveVPNRoute(prefix, intf) +} diff --git a/client/internal/routemanager/systemops_nonlinux_test.go b/client/internal/routemanager/systemops_nonlinux_test.go deleted file mode 100644 index afaf5ba77..000000000 --- a/client/internal/routemanager/systemops_nonlinux_test.go +++ /dev/null @@ -1,80 +0,0 @@ -//go:build !linux || android - -package routemanager - -import ( - "net" - "net/netip" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestIsSubRange(t *testing.T) { - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var subRangeAddressPrefixes []netip.Prefix - var nonSubRangeAddressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { - p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) - subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) - nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) - } - } - - for _, prefix := range subRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if !isSubRangePrefix { - t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) - } - } - - for _, prefix := range nonSubRangeAddressPrefixes { - isSubRangePrefix, err := isSubRange(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address is sub-range: ", err) - } - if isSubRangePrefix { - t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) - } - } -} - -func TestExistsInRouteTable(t *testing.T) { - require.NoError(t, setupRouting()) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - addresses, err := net.InterfaceAddrs() - if err != nil { - t.Fatal("shouldn't return error when fetching interface addresses: ", err) - } - - var addressPrefixes []netip.Prefix - for _, address := range addresses { - p := netip.MustParsePrefix(address.String()) - if p.Addr().Is4() { - addressPrefixes = append(addressPrefixes, p.Masked()) - } - } - - for _, prefix := range addressPrefixes { - exists, err := existsInRouteTable(prefix) - if err != nil { - t.Fatal("shouldn't return error when checking if address exists in route table: ", err) - } - if !exists { - t.Fatalf("address %s should exist in route table", prefix) - } - } -} diff --git a/client/internal/routemanager/systemops_nonandroid_test.go b/client/internal/routemanager/systemops_test.go similarity index 50% rename from client/internal/routemanager/systemops_nonandroid_test.go rename to client/internal/routemanager/systemops_test.go index aae5e5faa..97386f19a 100644 --- a/client/internal/routemanager/systemops_nonandroid_test.go +++ b/client/internal/routemanager/systemops_test.go @@ -1,14 +1,14 @@ -//go:build !android +//go:build !android && !ios package routemanager import ( "bytes" + "context" "fmt" "net" "net/netip" "os" - "os/exec" "runtime" "strings" "testing" @@ -22,47 +22,9 @@ import ( "github.com/netbirdio/netbird/iface" ) -func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { - t.Helper() - - if runtime.GOOS == "linux" { - outIntf, err := getOutgoingInterfaceLinux(prefix.Addr().String()) - require.NoError(t, err, "getOutgoingInterfaceLinux should not return error") - if invert { - require.NotEqual(t, wgIface.Name(), outIntf, "outgoing interface should not be the wireguard interface") - } else { - require.Equal(t, wgIface.Name(), outIntf, "outgoing interface should be the wireguard interface") - } - return - } - - prefixGateway, err := getExistingRIBRouteGateway(prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") - if invert { - assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") - } else { - assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") - } -} - -func getOutgoingInterfaceLinux(destination string) (string, error) { - cmd := exec.Command("ip", "route", "get", destination) - output, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("executing ip route get: %w", err) - } - - return parseOutgoingInterface(string(output)), nil -} - -func parseOutgoingInterface(routeGetOutput string) string { - fields := strings.Fields(routeGetOutput) - for i, field := range fields { - if field == "dev" && i+1 < len(fields) { - return fields[i+1] - } - } - return "" +type dialer interface { + Dial(network, address string) (net.Conn, error) + DialContext(ctx context.Context, network, address string) (net.Conn, error) } func TestAddRemoveRoutes(t *testing.T) { @@ -99,14 +61,14 @@ func TestAddRemoveRoutes(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - - require.NoError(t, setupRouting()) + _, _, err = setupRouting(nil, nil) + require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, cleanupRouting()) }) - err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "addToRouteTableIfNoExists should not return err") + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericAddVPNRoute should not return err") if testCase.shouldRouteToWireguard { assertWGOutInterface(t, testCase.prefix, wgInterface, false) @@ -116,13 +78,13 @@ func TestAddRemoveRoutes(t *testing.T) { exists, err := existsInRouteTable(testCase.prefix) require.NoError(t, err, "existsInRouteTable should not return err") if exists && testCase.shouldRouteToWireguard { - err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String(), wgInterface.Name()) - require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err") + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) + require.NoError(t, err, "genericRemoveVPNRoute should not return err") - prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix) - require.NoError(t, err, "getExistingRIBRouteGateway should not return err") + prefixGateway, _, err := getNextHop(testCase.prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") - internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) + internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) require.NoError(t, err) if testCase.shouldBeRemoved { @@ -135,12 +97,12 @@ func TestAddRemoveRoutes(t *testing.T) { } } -func TestGetExistingRIBRouteGateway(t *testing.T) { - gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) +func TestGetNextHop(t *testing.T) { + gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) } - if gateway == nil { + if !gateway.IsValid() { t.Fatal("should return a gateway") } addresses, err := net.InterfaceAddrs() @@ -162,11 +124,11 @@ func TestGetExistingRIBRouteGateway(t *testing.T) { } } - localIP, err := getExistingRIBRouteGateway(testingPrefix) + localIP, _, err := getNextHop(testingPrefix.Addr()) if err != nil { t.Fatal("shouldn't return error: ", err) } - if localIP == nil { + if !localIP.IsValid() { t.Fatal("should return a gateway for local network") } if localIP.String() == gateway.String() { @@ -177,8 +139,8 @@ func TestGetExistingRIBRouteGateway(t *testing.T) { } } -func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { - defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0")) +func TestAddExistAndRemoveRoute(t *testing.T) { + defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0")) t.Log("defaultGateway: ", defaultGateway) if err != nil { t.Fatal("shouldn't return error when fetching the gateway: ", err) @@ -238,21 +200,14 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { err = wgInterface.Create() require.NoError(t, err, "should create testing wireguard interface") - require.NoError(t, setupRouting()) - t.Cleanup(func() { - assert.NoError(t, cleanupRouting()) - }) - - MockAddr := wgInterface.Address().IP.String() - // Prepare the environment if testCase.preExistingPrefix.IsValid() { - err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr, wgInterface.Name()) + err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding pre-existing route") } // Add the route - err = addToRouteTableIfNoExists(testCase.prefix, MockAddr, wgInterface.Name()) + err = genericAddVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err when adding route") if testCase.shouldAddRoute { @@ -262,7 +217,7 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { require.True(t, ok, "route should exist") // remove route again if added - err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr, wgInterface.Name()) + err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name()) require.NoError(t, err, "should not return err") } @@ -272,11 +227,176 @@ func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) { t.Log("Buffer string: ", buf.String()) require.NoError(t, err, "should not return err") - // Linux uses a separate routing table, so the route can exist in both tables. - // The main routing table takes precedence over the wireguard routing table. - if !strings.Contains(buf.String(), "because it already exists") && runtime.GOOS != "linux" { + if !strings.Contains(buf.String(), "because it already exists") { require.False(t, ok, "route should not exist") } }) } } + +func TestIsSubRange(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var subRangeAddressPrefixes []netip.Prefix + var nonSubRangeAddressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if !p.Addr().IsLoopback() && p.Addr().Is4() && p.Bits() < 32 { + p2 := netip.PrefixFrom(p.Masked().Addr(), p.Bits()+1) + subRangeAddressPrefixes = append(subRangeAddressPrefixes, p2) + nonSubRangeAddressPrefixes = append(nonSubRangeAddressPrefixes, p.Masked()) + } + } + + for _, prefix := range subRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if !isSubRangePrefix { + t.Fatalf("address %s should be sub-range of an existing route in the table", prefix) + } + } + + for _, prefix := range nonSubRangeAddressPrefixes { + isSubRangePrefix, err := isSubRange(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address is sub-range: ", err) + } + if isSubRangePrefix { + t.Fatalf("address %s should not be sub-range of an existing route in the table", prefix) + } + } +} + +func TestExistsInRouteTable(t *testing.T) { + addresses, err := net.InterfaceAddrs() + if err != nil { + t.Fatal("shouldn't return error when fetching interface addresses: ", err) + } + + var addressPrefixes []netip.Prefix + for _, address := range addresses { + p := netip.MustParsePrefix(address.String()) + if p.Addr().Is6() { + continue + } + // Windows sometimes has hidden interface link local addrs that don't turn up on any interface + if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() { + continue + } + // Linux loopback 127/8 is in the local table, not in the main table and always takes precedence + if runtime.GOOS == "linux" && p.Addr().IsLoopback() { + continue + } + + addressPrefixes = append(addressPrefixes, p.Masked()) + } + + for _, prefix := range addressPrefixes { + exists, err := existsInRouteTable(prefix) + if err != nil { + t.Fatal("shouldn't return error when checking if address exists in route table: ", err) + } + if !exists { + t.Fatalf("address %s should exist in route table", prefix) + } + } +} + +func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface { + t.Helper() + + peerPrivateKey, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + newNet, err := stdnet.NewNet() + require.NoError(t, err) + + wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) + require.NoError(t, err, "should create testing WireGuard interface") + + err = wgInterface.Create() + require.NoError(t, err, "should create testing WireGuard interface") + + t.Cleanup(func() { + wgInterface.Close() + }) + + return wgInterface +} + +func setupTestEnv(t *testing.T) { + t.Helper() + + setupDummyInterfacesAndRoutes(t) + + wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820) + t.Cleanup(func() { + assert.NoError(t, wgIface.Close()) + }) + + _, _, err := setupRouting(nil, wgIface) + require.NoError(t, err, "setupRouting should not return err") + t.Cleanup(func() { + assert.NoError(t, cleanupRouting()) + }) + + // default route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.0.0.0/8 route exists in main table and vpn table + err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 10.10.0.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // 127.0.10.0/24 more specific route exists in vpn table + err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) + + // unique route in vpn table + err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + require.NoError(t, err, "addVPNRoute should not return err") + t.Cleanup(func() { + err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name()) + assert.NoError(t, err, "removeVPNRoute should not return err") + }) +} + +func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) { + t.Helper() + if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() { + return + } + + prefixGateway, _, err := getNextHop(prefix.Addr()) + require.NoError(t, err, "getNextHop should not return err") + if invert { + assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP") + } else { + assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP") + } +} diff --git a/client/internal/routemanager/systemops_unix_test.go b/client/internal/routemanager/systemops_unix_test.go new file mode 100644 index 000000000..561eaeea4 --- /dev/null +++ b/client/internal/routemanager/systemops_unix_test.go @@ -0,0 +1,234 @@ +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly + +package routemanager + +import ( + "fmt" + "net" + "strings" + "testing" + "time" + + "github.com/gopacket/gopacket" + "github.com/gopacket/gopacket/layers" + "github.com/gopacket/gopacket/pcap" + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +type PacketExpectation struct { + SrcIP net.IP + DstIP net.IP + SrcPort int + DstPort int + UDP bool + TCP bool +} + +type testCase struct { + name string + destination string + expectedInterface string + dialer dialer + expectedPacket PacketExpectation +} + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53), + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedInterface: expectedInternalInt, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53), + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedInterface: expectedExternalInt, + dialer: nbnet.NewDialer(), + expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedInterface: expectedVPNint, + dialer: &net.Dialer{}, + expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53), + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + filter := createBPFFilter(tc.destination) + handle := startPacketCapture(t, tc.expectedInterface, filter) + + sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer) + + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + packet, err := packetSource.NextPacket() + require.NoError(t, err) + + verifyPacket(t, packet, tc.expectedPacket) + }) + } +} + +func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation { + return PacketExpectation{ + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + UDP: true, + } +} + +func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle { + t.Helper() + + inactive, err := pcap.NewInactiveHandle(intf) + require.NoError(t, err, "Failed to create inactive pcap handle") + defer inactive.CleanUp() + + err = inactive.SetSnapLen(1600) + require.NoError(t, err, "Failed to set snap length on inactive handle") + + err = inactive.SetTimeout(time.Second * 10) + require.NoError(t, err, "Failed to set timeout on inactive handle") + + err = inactive.SetImmediateMode(true) + require.NoError(t, err, "Failed to set immediate mode on inactive handle") + + handle, err := inactive.Activate() + require.NoError(t, err, "Failed to activate pcap handle") + t.Cleanup(handle.Close) + + err = handle.SetBPFFilter(filter) + require.NoError(t, err, "Failed to set BPF filter") + + return handle +} + +func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) { + t.Helper() + + if dialer == nil { + dialer = &net.Dialer{} + } + + if sourcePort != 0 { + localUDPAddr := &net.UDPAddr{ + IP: net.IPv4zero, + Port: sourcePort, + } + switch dialer := dialer.(type) { + case *nbnet.Dialer: + dialer.LocalAddr = localUDPAddr + case *net.Dialer: + dialer.LocalAddr = localUDPAddr + default: + t.Fatal("Unsupported dialer type") + } + } + + msg := new(dns.Msg) + msg.Id = dns.Id() + msg.RecursionDesired = true + msg.Question = []dns.Question{ + {Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + + conn, err := dialer.Dial("udp", destination) + require.NoError(t, err, "Failed to dial UDP") + defer conn.Close() + + data, err := msg.Pack() + require.NoError(t, err, "Failed to pack DNS message") + + _, err = conn.Write(data) + if err != nil { + if strings.Contains(err.Error(), "required key not available") { + t.Logf("Ignoring WireGuard key error: %v", err) + return + } + t.Fatalf("Failed to send DNS query: %v", err) + } +} + +func createBPFFilter(destination string) string { + host, port, err := net.SplitHostPort(destination) + if err != nil { + return fmt.Sprintf("udp and dst host %s and dst port %s", host, port) + } + return "udp" +} + +func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) { + t.Helper() + + ipLayer := packet.Layer(layers.LayerTypeIPv4) + require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet") + + ip, ok := ipLayer.(*layers.IPv4) + require.True(t, ok, "Failed to cast to IPv4 layer") + + // Convert both source and destination IP addresses to 16-byte representation + expectedSrcIP := exp.SrcIP.To16() + actualSrcIP := ip.SrcIP.To16() + assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch") + + expectedDstIP := exp.DstIP.To16() + actualDstIP := ip.DstIP.To16() + assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch") + + if exp.UDP { + udpLayer := packet.Layer(layers.LayerTypeUDP) + require.NotNil(t, udpLayer, "Expected UDP layer not found in packet") + + udp, ok := udpLayer.(*layers.UDP) + require.True(t, ok, "Failed to cast to UDP layer") + + assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch") + assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch") + } + + if exp.TCP { + tcpLayer := packet.Layer(layers.LayerTypeTCP) + require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet") + + tcp, ok := tcpLayer.(*layers.TCP) + require.True(t, ok, "Failed to cast to TCP layer") + + assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch") + assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch") + } +} diff --git a/client/internal/routemanager/systemops_windows.go b/client/internal/routemanager/systemops_windows.go index c009ce66b..334ace453 100644 --- a/client/internal/routemanager/systemops_windows.go +++ b/client/internal/routemanager/systemops_windows.go @@ -6,9 +6,14 @@ import ( "fmt" "net" "net/netip" + "os/exec" + "strings" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/iface" ) type Win32_IP4RouteTable struct { @@ -16,6 +21,16 @@ type Win32_IP4RouteTable struct { Mask string } +var routeManager *RouteManager + +func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) { + return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface) +} + +func cleanupRouting() error { + return cleanupRoutingWithRouteManager(routeManager) +} + func getRoutesFromTable() ([]netip.Prefix, error) { var routes []Win32_IP4RouteTable query := "SELECT Destination, Mask FROM Win32_IP4RouteTable" @@ -48,10 +63,85 @@ func getRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } -func addToRouteTableIfNoExists(prefix netip.Prefix, addr string, intf string) error { - return genericAddToRouteTableIfNoExists(prefix, addr, intf) +func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf, intfIdx string) error { + destinationPrefix := prefix.String() + psCmd := "New-NetRoute" + + addressFamily := "IPv4" + if prefix.Addr().Is6() { + addressFamily = "IPv6" + } + + script := fmt.Sprintf( + `%s -AddressFamily "%s" -DestinationPrefix "%s" -Confirm:$False -ErrorAction Stop`, + psCmd, addressFamily, destinationPrefix, + ) + + if intfIdx != "" { + script = fmt.Sprintf( + `%s -InterfaceIndex %s`, script, intfIdx, + ) + } else { + script = fmt.Sprintf( + `%s -InterfaceAlias "%s"`, script, intf, + ) + } + + if nexthop.IsValid() { + script = fmt.Sprintf( + `%s -NextHop "%s"`, script, nexthop, + ) + } + + out, err := exec.Command("powershell", "-Command", script).CombinedOutput() + log.Tracef("PowerShell %s: %s", script, string(out)) + + if err != nil { + return fmt.Errorf("PowerShell add route: %w", err) + } + + return nil } -func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string, intf string) error { - return genericRemoveFromRouteTableIfNonSystem(prefix, addr, intf) +func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"add", prefix.String(), nexthop.Unmap().String()} + + out, err := exec.Command("route", args...).CombinedOutput() + + log.Tracef("route %s: %s", strings.Join(args, " "), out) + if err != nil { + return fmt.Errorf("route add: %w", err) + } + + return nil +} + +func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error { + var intfIdx string + if nexthop.Zone() != "" { + intfIdx = nexthop.Zone() + nexthop.WithZone("") + } + + // Powershell doesn't support adding routes without an interface but allows to add interface by name + if intf != "" || intfIdx != "" { + return addRoutePowershell(prefix, nexthop, intf, intfIdx) + } + return addRouteCmd(prefix, nexthop, intf) +} + +func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error { + args := []string{"delete", prefix.String()} + if nexthop.IsValid() { + nexthop.WithZone("") + args = append(args, nexthop.Unmap().String()) + } + + out, err := exec.Command("route", args...).CombinedOutput() + log.Tracef("route %s: %s", strings.Join(args, " "), out) + + if err != nil { + return fmt.Errorf("remove route: %w", err) + } + return nil } diff --git a/client/internal/routemanager/systemops_windows_test.go b/client/internal/routemanager/systemops_windows_test.go new file mode 100644 index 000000000..a5e03b8d2 --- /dev/null +++ b/client/internal/routemanager/systemops_windows_test.go @@ -0,0 +1,289 @@ +package routemanager + +import ( + "context" + "encoding/json" + "fmt" + "net" + "os/exec" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +var expectedExtInt = "Ethernet1" + +type RouteInfo struct { + NextHop string `json:"nexthop"` + InterfaceAlias string `json:"interfacealias"` + RouteMetric int `json:"routemetric"` +} + +type FindNetRouteOutput struct { + IPAddress string `json:"IPAddress"` + InterfaceIndex int `json:"InterfaceIndex"` + InterfaceAlias string `json:"InterfaceAlias"` + AddressFamily int `json:"AddressFamily"` + NextHop string `json:"NextHop"` + DestinationPrefix string `json:"DestinationPrefix"` +} + +type testCase struct { + name string + destination string + expectedSourceIP string + expectedDestPrefix string + expectedNextHop string + expectedInterface string + dialer dialer +} + +var expectedVPNint = "wgtest0" + +var testCases = []testCase{ + { + name: "To external host without custom dialer via vpn", + destination: "192.0.2.1:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "128.0.0.0/1", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + { + name: "To external host with custom dialer via physical interface", + destination: "192.0.2.1:53", + expectedDestPrefix: "192.0.2.1/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + + { + name: "To duplicate internal route with custom dialer via physical interface", + destination: "10.0.0.2:53", + expectedDestPrefix: "10.0.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence + destination: "10.0.0.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "10.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, + + { + name: "To unique vpn route with custom dialer via physical interface", + destination: "172.16.0.2:53", + expectedDestPrefix: "172.16.0.2/32", + expectedInterface: expectedExtInt, + dialer: nbnet.NewDialer(), + }, + { + name: "To unique vpn route without custom dialer via vpn", + destination: "172.16.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "172.16.0.0/12", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route without custom dialer via vpn interface", + destination: "10.10.0.2:53", + expectedSourceIP: "100.64.0.1", + expectedDestPrefix: "10.10.0.0/24", + expectedNextHop: "0.0.0.0", + expectedInterface: "wgtest0", + dialer: &net.Dialer{}, + }, + + { + name: "To more specific route (local) without custom dialer via physical interface", + destination: "127.0.10.2:53", + expectedSourceIP: "10.0.0.1", + expectedDestPrefix: "127.0.0.0/8", + expectedNextHop: "0.0.0.0", + expectedInterface: "Loopback Pseudo-Interface 1", + dialer: &net.Dialer{}, + }, +} + +func TestRouting(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + setupTestEnv(t) + + route, err := fetchOriginalGateway() + require.NoError(t, err, "Failed to fetch original gateway") + ip, err := fetchInterfaceIP(route.InterfaceAlias) + require.NoError(t, err, "Failed to fetch interface IP") + + output := testRoute(t, tc.destination, tc.dialer) + if tc.expectedInterface == expectedExtInt { + verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias) + } else { + verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface) + } + }) + } +} + +// fetchInterfaceIP fetches the IPv4 address of the specified interface. +func fetchInterfaceIP(interfaceAlias string) (string, error) { + script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias) + out, err := exec.Command("powershell", "-Command", script).Output() + if err != nil { + return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err) + } + + ip := strings.TrimSpace(string(out)) + return ip, nil +} + +func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", destination) + require.NoError(t, err, "Failed to dial destination") + defer func() { + err := conn.Close() + assert.NoError(t, err, "Failed to close connection") + }() + + host, _, err := net.SplitHostPort(destination) + require.NoError(t, err) + + script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host) + + out, err := exec.Command("powershell", "-Command", script).Output() + require.NoError(t, err, "Failed to execute Find-NetRoute") + + var outputs []FindNetRouteOutput + err = json.Unmarshal(out, &outputs) + require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute") + + require.Greater(t, len(outputs), 0, "No route found for destination") + combinedOutput := combineOutputs(outputs) + + return combinedOutput +} + +func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string { + t.Helper() + + ip, ipNet, err := net.ParseCIDR(ipAddressCIDR) + require.NoError(t, err) + subnetMaskSize, _ := ipNet.Mask.Size() + script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to assign IP address to loopback adapter") + + // Wait for the IP address to be applied + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + err = waitForIPAddress(ctx, interfaceName, ip.String()) + require.NoError(t, err, "IP address not applied within timeout") + + t.Cleanup(func() { + script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String()) + _, err = exec.Command("powershell", "-Command", script).CombinedOutput() + require.NoError(t, err, "Failed to remove IP address from loopback adapter") + }) + + return interfaceName +} + +func fetchOriginalGateway() (*RouteInfo, error) { + cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json") + output, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err) + } + + var routeInfo RouteInfo + err = json.Unmarshal(output, &routeInfo) + if err != nil { + return nil, fmt.Errorf("failed to parse JSON output: %w", err) + } + + return &routeInfo, nil +} + +func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) { + t.Helper() + + assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch") + assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch") + assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch") + assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch") +} + +func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput() + if err != nil { + return err + } + + ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n") + for _, ip := range ipAddresses { + if strings.TrimSpace(ip) == expectedIPAddress { + return nil + } + } + } + } +} + +func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput { + var combined FindNetRouteOutput + + for _, output := range outputs { + if output.IPAddress != "" { + combined.IPAddress = output.IPAddress + } + if output.InterfaceIndex != 0 { + combined.InterfaceIndex = output.InterfaceIndex + } + if output.InterfaceAlias != "" { + combined.InterfaceAlias = output.InterfaceAlias + } + if output.AddressFamily != 0 { + combined.AddressFamily = output.AddressFamily + } + if output.NextHop != "" { + combined.NextHop = output.NextHop + } + if output.DestinationPrefix != "" { + combined.DestinationPrefix = output.DestinationPrefix + } + } + + return &combined +} + +func setupDummyInterfacesAndRoutes(t *testing.T) { + t.Helper() + + createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8") +} diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/portlookup.go index 6ede4b83f..6f3d33487 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/portlookup.go @@ -1,10 +1,8 @@ package wgproxy import ( - "context" "fmt" - - nbnet "github.com/netbirdio/netbird/util/net" + "net" ) const ( @@ -25,7 +23,7 @@ func (pl portLookup) searchFreePort() (int, error) { } func (pl portLookup) tryToBind(port int) error { - l, err := nbnet.NewListener().ListenPacket(context.Background(), "udp", fmt.Sprintf(":%d", port)) + l, err := net.ListenPacket("udp", fmt.Sprintf(":%d", port)) if err != nil { return err } diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/proxy_ebpf.go index b91cd7b43..2235c5d2b 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/proxy_ebpf.go @@ -12,6 +12,7 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/ebpf" @@ -29,7 +30,7 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex rawConn net.PacketConn - conn *net.UDPConn + conn transport.UDPConn } // NewWGEBPFProxy create new WGEBPFProxy instance @@ -67,7 +68,7 @@ func (p *WGEBPFProxy) Listen() error { IP: net.ParseIP("127.0.0.1"), } - p.conn, err = nbnet.ListenUDP("udp", &addr) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { cErr := p.Free() if cErr != nil { @@ -75,6 +76,7 @@ func (p *WGEBPFProxy) Listen() error { } return err } + p.conn = conn go p.proxyToRemote() log.Infof("local wg proxy listening on: %d", wgPorxyPort) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 81998b115..4b8502268 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -44,17 +44,18 @@ type LoginRequest struct { // cleanNATExternalIPs clean map list of external IPs. // This is needed because the generated code // omits initialized empty slices due to omitempty tags - CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` - CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` - IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` - Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` - RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` - InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` - WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` - OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` - DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` - ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` - RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + CleanNATExternalIPs bool `protobuf:"varint,6,opt,name=cleanNATExternalIPs,proto3" json:"cleanNATExternalIPs,omitempty"` + CustomDNSAddress []byte `protobuf:"bytes,7,opt,name=customDNSAddress,proto3" json:"customDNSAddress,omitempty"` + IsLinuxDesktopClient bool `protobuf:"varint,8,opt,name=isLinuxDesktopClient,proto3" json:"isLinuxDesktopClient,omitempty"` + Hostname string `protobuf:"bytes,9,opt,name=hostname,proto3" json:"hostname,omitempty"` + RosenpassEnabled *bool `protobuf:"varint,10,opt,name=rosenpassEnabled,proto3,oneof" json:"rosenpassEnabled,omitempty"` + InterfaceName *string `protobuf:"bytes,11,opt,name=interfaceName,proto3,oneof" json:"interfaceName,omitempty"` + WireguardPort *int64 `protobuf:"varint,12,opt,name=wireguardPort,proto3,oneof" json:"wireguardPort,omitempty"` + OptionalPreSharedKey *string `protobuf:"bytes,13,opt,name=optionalPreSharedKey,proto3,oneof" json:"optionalPreSharedKey,omitempty"` + DisableAutoConnect *bool `protobuf:"varint,14,opt,name=disableAutoConnect,proto3,oneof" json:"disableAutoConnect,omitempty"` + ServerSSHAllowed *bool `protobuf:"varint,15,opt,name=serverSSHAllowed,proto3,oneof" json:"serverSSHAllowed,omitempty"` + RosenpassPermissive *bool `protobuf:"varint,16,opt,name=rosenpassPermissive,proto3,oneof" json:"rosenpassPermissive,omitempty"` + ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` } func (x *LoginRequest) Reset() { @@ -202,6 +203,13 @@ func (x *LoginRequest) GetRosenpassPermissive() bool { return false } +func (x *LoginRequest) GetExtraIFaceBlacklist() []string { + if x != nil { + return x.ExtraIFaceBlacklist + } + return nil +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -1385,7 +1393,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xdd, 0x06, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x8f, 0x07, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -1430,192 +1438,195 @@ var file_daemon_proto_rawDesc = []byte{ 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x48, 0x06, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, - 0x88, 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, - 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, - 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, - 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, 0x0a, 0x15, - 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, - 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, - 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, - 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6e, - 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x28, 0x0a, - 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, - 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, - 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, - 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x74, - 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, - 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, - 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, - 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x32, 0x0a, - 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, - 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, 0x11, 0x47, - 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, - 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, 0x6c, 0x65, - 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, - 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, 0x52, 0x4c, - 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, - 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, - 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, - 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, - 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, - 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, - 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, - 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x88, 0x01, 0x01, 0x12, 0x30, 0x0a, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, + 0x65, 0x42, 0x6c, 0x61, 0x63, 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x11, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x13, 0x65, 0x78, 0x74, 0x72, 0x61, 0x49, 0x46, 0x61, 0x63, 0x65, 0x42, 0x6c, 0x61, 0x63, + 0x6b, 0x6c, 0x69, 0x73, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, + 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, + 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, 0x17, + 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, + 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, + 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, + 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, + 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xb5, 0x01, 0x0a, 0x0d, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, + 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, + 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, + 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, 0x72, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, + 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, 0x6c, + 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, + 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, + 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, 0x70, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, + 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, + 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, 0x6c, + 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, + 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, 0x77, + 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb3, 0x01, 0x0a, + 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, 0x46, + 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, 0x69, + 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, + 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, + 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, + 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, 0x55, + 0x52, 0x4c, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, + 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, + 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, + 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, + 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, + 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, + 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, - 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, - 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, - 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, - 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, - 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, - 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, - 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, - 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, - 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, - 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, - 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, - 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, - 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, - 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, - 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, - 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, - 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, + 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, + 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, + 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, + 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, - 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, - 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, - 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, - 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, - 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, - 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, - 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, - 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, - 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, - 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, 0x0a, 0x0d, - 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, - 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, - 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, - 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, - 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, - 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, + 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, + 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, + 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, + 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, + 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, + 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, + 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, + 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, + 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, + 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, + 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, + 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, + 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, + 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, + 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x32, 0xf7, 0x02, + 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, + 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, + 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, + 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, + 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 8f9148d68..5f8878a11 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -70,6 +70,8 @@ message LoginRequest { optional bool serverSSHAllowed = 15; optional bool rosenpassPermissive = 16; + + repeated string extraIFaceBlacklist = 17; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 481ef0f7c..d1d9dbda4 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -152,7 +152,8 @@ func (s *Server) Start() error { // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, - mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe) { + mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, +) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -351,6 +352,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.WireguardPort = &port } + if len(msg.ExtraIFaceBlacklist) > 0 { + inputConfig.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + s.latestConfigInput.ExtraIFaceBlackList = msg.ExtraIFaceBlacklist + } + s.mutex.Unlock() if msg.OptionalPreSharedKey != nil { diff --git a/client/server/server_test.go b/client/server/server_test.go index 7f8310c90..4e4a09145 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "context" + "github.com/netbirdio/management-integrations/integrations" "net" "testing" "time" @@ -114,7 +115,8 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { return nil, "", err } diff --git a/client/system/detect_cloud/detect.go b/client/system/detect_cloud/detect.go index 3bbff4345..8a8de763e 100644 --- a/client/system/detect_cloud/detect.go +++ b/client/system/detect_cloud/detect.go @@ -25,8 +25,6 @@ func Detect(ctx context.Context) string { detectDigitalOcean, detectGCP, detectOracle, - detectIBMCloud, - detectSoftlayer, detectVultr, } diff --git a/client/system/detect_cloud/gcp.go b/client/system/detect_cloud/gcp.go index c673f8937..a24c38c0c 100644 --- a/client/system/detect_cloud/gcp.go +++ b/client/system/detect_cloud/gcp.go @@ -6,7 +6,7 @@ import ( ) func detectGCP(ctx context.Context) string { - req, err := http.NewRequestWithContext(ctx, "GET", "http://metadata.google.internal", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "http://169.254.169.254", nil) if err != nil { return "" } diff --git a/client/system/detect_cloud/ibmcloud.go b/client/system/detect_cloud/ibmcloud.go deleted file mode 100644 index 07de6a2ee..000000000 --- a/client/system/detect_cloud/ibmcloud.go +++ /dev/null @@ -1,54 +0,0 @@ -package detect_cloud - -import ( - "context" - "net/http" -) - -func detectIBMCloud(ctx context.Context) string { - v1ResultChan := make(chan bool, 1) - v2ResultChan := make(chan bool, 1) - - go func() { - v1ResultChan <- detectIBMSecure(ctx) - }() - - go func() { - v2ResultChan <- detectIBM(ctx) - }() - - v1Result, v2Result := <-v1ResultChan, <-v2ResultChan - - if v1Result || v2Result { - return "IBM Cloud" - } - return "" -} - -func detectIBMSecure(ctx context.Context) bool { - req, err := http.NewRequestWithContext(ctx, "PUT", "https://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil) - if err != nil { - return false - } - - resp, err := hc.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == http.StatusOK -} - -func detectIBM(ctx context.Context) bool { - req, err := http.NewRequestWithContext(ctx, "PUT", "http://api.metadata.cloud.ibm.com/instance_identity/v1/token", nil) - if err != nil { - return false - } - - resp, err := hc.Do(req) - if err != nil { - return false - } - defer resp.Body.Close() - return resp.StatusCode == http.StatusOK -} diff --git a/client/system/detect_cloud/softlayer.go b/client/system/detect_cloud/softlayer.go deleted file mode 100644 index a09b522c4..000000000 --- a/client/system/detect_cloud/softlayer.go +++ /dev/null @@ -1,25 +0,0 @@ -package detect_cloud - -import ( - "context" - "net/http" -) - -func detectSoftlayer(ctx context.Context) string { - req, err := http.NewRequestWithContext(ctx, "GET", "https://api.service.softlayer.com/rest/v3/SoftLayer_Resource_Metadata/UserMetadata.txt", nil) - if err != nil { - return "" - } - - resp, err := hc.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - // Since SoftLayer was acquired by IBM, we should return "IBM Cloud" - return "IBM Cloud" - } - return "" -} diff --git a/client/system/info_linux.go b/client/system/info_linux.go index ca3be9d1c..652bc1115 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -120,5 +120,5 @@ func _getReleaseInfo() string { func sysInfo() (serialNumber string, productName string, manufacturer string) { var si sysinfo.SysInfo si.GetSysInfo() - return si.Product.Version, si.Product.Name, si.Product.Vendor + return si.Chassis.Serial, si.Product.Name, si.Product.Vendor } diff --git a/go.mod b/go.mod index 67ec9c42e..29a1570c8 100644 --- a/go.mod +++ b/go.mod @@ -46,21 +46,21 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.9 github.com/google/gopacket v1.1.19 + github.com/google/martian/v3 v3.0.0 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/gopacket/gopacket v1.1.1 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 - github.com/libp2p/go-netroute v0.2.0 + github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.5 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/socket v0.4.1 github.com/miekg/dns v1.1.43 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 + github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/go.sum b/go.sum index c36b8aff3..b488a42a4 100644 --- a/go.sum +++ b/go.sum @@ -255,6 +255,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/martian/v3 v3.0.0 h1:pMen7vLs8nvgEYhywH3KDWJIJTeEr2ULsVWHWYHQyBs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/nftables v0.0.0-20220808154552-2eca00135732 h1:csc7dT82JiSLvq4aMyQMIQDL7986NH6Wxf/QrvOj55A= github.com/google/nftables v0.0.0-20220808154552-2eca00135732/go.mod h1:b97ulCCFipUC+kSin+zygkvUVpx0vyIAwxXFdY3PlNc= @@ -344,8 +345,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= -github.com/libp2p/go-netroute v0.2.0 h1:0FpsbsvuSnAhXFnCY0VLFbJOzaK0VnP0r1QT/o4nWRE= -github.com/libp2p/go-netroute v0.2.0/go.mod h1:Vio7LTzZ+6hoT4CMZi5/6CpY3Snzh2vgZhWgxMNwlQI= +github.com/libp2p/go-netroute v0.2.1 h1:V8kVrpD8GK0Riv15/7VN6RbUQ3URNZVosw7H2v9tksU= +github.com/libp2p/go-netroute v0.2.1/go.mod h1:hraioZr0fhBjG0ZRXJJ6Zj2IVEVNx6tDTFQfSmcq7mQ= github.com/lucor/goinfo v0.0.0-20210802170112-c078a2b0f08b/go.mod h1:PRq09yoB+Q2OJReAmwzKivcYyremnibWGbK7WfftHzc= github.com/magiconair/properties v1.8.5 h1:b6kJs+EmPFMYGkow9GiUyCyOvIwYetYJ3fSaWak/Gls= github.com/magiconair/properties v1.8.5/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= @@ -382,10 +383,8 @@ github.com/nadoo/ipset v0.5.0 h1:5GJUAuZ7ITQQQGne5J96AmFjRtI8Avlbk6CabzYWVUc= github.com/nadoo/ipset v0.5.0/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552 h1:yzcQKizAK9YufCHMMCIsr467Dw/OU/4xyHbWizGb1E4= -github.com/netbirdio/management-integrations/additions v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:31FhBNvQ+riHEIu6LSTmqr8IeuSIsGfQffqV4LFmbwA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552 h1:OFlzVZtkXCoJsfDKrMigFpuad8ZXTm8epq6x27K0irA= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240212121739-8ea8c89a4552/go.mod h1:B0nMS3es77gOvPYhc0K91fAzTkQLi/jRq5TffUN3klM= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98 h1:i6AtenTLu/CqhTmj0g1K/GWkkpMJMhQM6Vjs46x25nA= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240326083846-3682438fca98/go.mod h1:kxks50DrZnhW+oRTdHOkVOJbcTcyo766am8RBugo+Yc= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 h1:xbWM9BU6mwZZLHxEjxIX/V8Hv3HurQt4mReIE4mY4DM= @@ -660,7 +659,6 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= @@ -747,7 +745,6 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210426080607-c94f62235c83/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 47cadb93c..747eebd53 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -58,6 +58,7 @@ services: command: [ "--port", "443", "--log-file", "console", + "--log-level", "info", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS", "--single-account-mode-domain=$NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN", "--dns-domain=$NETBIRD_MGMT_DNS_DOMAIN" diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index fd194a042..d3ae6529a 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -2,7 +2,7 @@ version: "3" services: #UI dashboard dashboard: - image: wiretrustee/dashboard:$NETBIRD_DASHBOARD_TAG + image: netbirdio/dashboard:$NETBIRD_DASHBOARD_TAG restart: unless-stopped #ports: # - 80:80 diff --git a/management/client/client_test.go b/management/client/client_test.go index f30ae0cfd..30f91c73b 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -3,6 +3,7 @@ package client import ( "context" "net" + "os" "path/filepath" "sync" "testing" @@ -15,6 +16,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" mgmt "github.com/netbirdio/netbird/management/server" @@ -30,6 +32,12 @@ import ( const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" +func TestMain(m *testing.M) { + _ = util.InitLog("debug", "console") + code := m.Run() + os.Exit(code) +} + func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Helper() level, _ := log.ParseLevel("debug") @@ -60,7 +68,8 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { peersUpdateManager := mgmt.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} - accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false) + ia, _ := integrations.NewIntegratedValidator(eventStore) + accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index e8bcdc97d..23d9c195c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/keepalive" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -172,8 +173,12 @@ var ( log.Infof("geo location service has been initialized from %s", config.Datadir) } + integratedPeerValidator, err := integrations.NewIntegratedValidator(eventStore) + if err != nil { + return fmt.Errorf("failed to initialize integrated peer validator: %v", err) + } accountManager, err := server.BuildManager(store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain, - dnsDomain, eventStore, geo, userDeleteFromIDPEnabled) + dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator) if err != nil { return fmt.Errorf("failed to build default manager: %v", err) } @@ -323,6 +328,7 @@ var ( SetupCloseHandler() <-stopCh + integratedPeerValidator.Stop() if geo != nil { _ = geo.Stop() } diff --git a/management/server/account.go b/management/server/account.go index 8b326d93a..20bd15ad6 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -21,14 +21,15 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/management-integrations/additions" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integrated_validator" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -85,11 +86,12 @@ type AccountManager interface { GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) - GetGroup(accountId, groupID string) (*Group, error) - GetGroupByName(groupName, accountID string) (*Group, error) - SaveGroup(accountID, userID string, group *Group) error + GetGroup(accountId, groupID, userID string) (*nbgroup.Group, error) + GetAllGroups(accountID, userID string) ([]*nbgroup.Group, error) + GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) + SaveGroup(accountID, userID string, group *nbgroup.Group) error DeleteGroup(accountId, userId, groupID string) error - ListGroups(accountId string) ([]*Group, error) + ListGroups(accountId string) ([]*nbgroup.Group, error) GroupAddPeer(accountId, groupID, peerID string) error GroupDeletePeer(accountId, groupID, peerID string) error GetPolicy(accountID, policyID, userID string) (*Policy, error) @@ -123,6 +125,9 @@ type AccountManager interface { DeletePostureChecks(accountID, postureChecksID, userID string) error ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager + UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error + GroupValidation(accountId string, groups []string) (bool, error) + GetValidatedPeers(account *Account) (map[string]struct{}, error) } type DefaultAccountManager struct { @@ -151,6 +156,8 @@ type DefaultAccountManager struct { // userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account userDeleteFromIDPEnabled bool + + integratedPeerValidator integrated_validator.IntegratedValidator } // Settings represents Account settings structure that can be modified via API and Dashboard @@ -162,6 +169,9 @@ type Settings struct { // Applies to all peers that have Peer.LoginExpirationEnabled set to true. PeerLoginExpiration time.Duration + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements + RegularUsersViewBlocked bool + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer GroupsPropagationEnabled bool @@ -188,6 +198,7 @@ func (s *Settings) Copy() *Settings { JWTGroupsClaimName: s.JWTGroupsClaimName, GroupsPropagationEnabled: s.GroupsPropagationEnabled, JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, } if s.Extra != nil { settings.Extra = s.Extra.Copy() @@ -213,8 +224,8 @@ type Account struct { PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*nbgroup.Group `gorm:"-"` + GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[string]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` @@ -226,6 +237,10 @@ type Account struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + type UserInfo struct { ID string `json:"id"` Email string `json:"email"` @@ -238,7 +253,8 @@ type UserInfo struct { NonDeletable bool `json:"non_deletable"` LastLogin time.Time `json:"last_login"` Issued string `json:"issued"` - IntegrationReference IntegrationReference `json:"-"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes @@ -262,7 +278,7 @@ func (a *Account) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer) []*rou return routes } -// filterRoutesByHAMembership filters and returns a list of routes that don't share the same HA route membership +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { var filteredRoutes []*route.Route for _, r := range routes { @@ -362,25 +378,26 @@ func (a *Account) GetRoutesByPrefix(prefix netip.Prefix) []*route.Route { } // GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *Group { +func (a *Account) GetGroup(groupID string) *nbgroup.Group { return a.Groups[groupID] } // GetPeerNetworkMap returns a group by ID if exists, nil otherwise -func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string) *NetworkMap { +func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ Network: a.Network.Copy(), } } - validatedPeers := additions.ValidatePeers([]*nbpeer.Peer{peer}) - if len(validatedPeers) == 0 { + + if _, ok := validatedPeersMap[peerID]; !ok { return &NetworkMap{ Network: a.Network.Copy(), } } - aclPeers, firewallRules := a.getPeerConnectionResources(peerID) + + aclPeers, firewallRules := a.getPeerConnectionResources(peerID, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer var expiredPeers []*nbpeer.Peer @@ -554,7 +571,7 @@ func (a *Account) FindUser(userID string) (*User, error) { } // FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. -func (a *Account) FindGroupByName(groupName string) (*Group, error) { +func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { for _, group := range a.Groups { if group.Name == groupName { return group, nil @@ -573,6 +590,20 @@ func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { return key, nil } +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + func (a *Account) getUserGroups(userID string) ([]string, error) { user, err := a.FindUser(userID) if err != nil { @@ -650,7 +681,7 @@ func (a *Account) Copy() *Account { setupKeys[id] = key.Copy() } - groups := map[string]*Group{} + groups := map[string]*nbgroup.Group{} for id, group := range a.Groups { groups[id] = group.Copy() } @@ -703,7 +734,7 @@ func (a *Account) Copy() *Account { } } -func (a *Account) GetGroupAll() (*Group, error) { +func (a *Account) GetGroupAll() (*nbgroup.Group, error) { for _, g := range a.Groups { if g.Name == "All" { return g, nil @@ -724,7 +755,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { return false } - existedGroupsByName := make(map[string]*Group) + existedGroupsByName := make(map[string]*nbgroup.Group) for _, group := range a.Groups { existedGroupsByName[group.Name] = group } @@ -733,7 +764,7 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { removed := 0 jwtAutoGroups := make(map[string]struct{}) for i, id := range user.AutoGroups { - if group, ok := a.Groups[id]; ok && group.Issued == GroupIssuedJWT { + if group, ok := a.Groups[id]; ok && group.Issued == nbgroup.GroupIssuedJWT { jwtAutoGroups[group.Name] = struct{}{} user.AutoGroups = append(user.AutoGroups[:i-removed], user.AutoGroups[i-removed+1:]...) removed++ @@ -746,15 +777,15 @@ func (a *Account) SetJWTGroups(userID string, groupsNames []string) bool { for _, name := range groupsNames { group, ok := existedGroupsByName[name] if !ok { - group = &Group{ + group = &nbgroup.Group{ ID: xid.New().String(), Name: name, - Issued: GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, } a.Groups[group.ID] = group } // only JWT groups will be synced - if group.Issued == GroupIssuedJWT { + if group.Issued == nbgroup.GroupIssuedJWT { user.AutoGroups = append(user.AutoGroups, group.ID) if _, ok := jwtAutoGroups[name]; !ok { modified = true @@ -827,6 +858,7 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation, userDeleteFromIDPEnabled bool, + integratedPeerValidator integrated_validator.IntegratedValidator, ) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, @@ -840,6 +872,7 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage eventStore: eventStore, peerLoginExpiry: NewDefaultScheduler(), userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, + integratedPeerValidator: integratedPeerValidator, } allAccounts := store.GetAllAccounts() // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -896,6 +929,8 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage }() } + am.integratedPeerValidator.SetPeerInvalidationListener(am.onPeersInvalidated) + return am, nil } @@ -938,7 +973,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") } - err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore) + err = am.integratedPeerValidator.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) if err != nil { return nil, err } @@ -1588,7 +1623,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain claims.DomainCategory = PrivateCategory - log.Infof("overriding JWT Domain and DomainCategory claims since single account mode is enabled") + log.Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } newAcc, err := am.getAccountWithAuthorizationClaims(claims) @@ -1813,18 +1848,27 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(claims jwtclaims.Aut return nil } +func (am *DefaultAccountManager) onPeersInvalidated(accountID string) { + updatedAccount, err := am.Store.GetAccount(accountID) + if err != nil { + log.Errorf("failed to get account %s: %v", accountID, err) + return + } + am.updateAccountPeers(updatedAccount) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { - allGroup := &Group{ + allGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "All", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*Group{allGroup.ID: allGroup} + account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} id := xid.New().String() @@ -1885,6 +1929,7 @@ func newAccountWithId(accountID, userID, domain string) *Account { PeerLoginExpirationEnabled: true, PeerLoginExpiration: DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, }, } diff --git a/management/server/account/account.go b/management/server/account/account.go index b8b71a6de..40f032fbe 100644 --- a/management/server/account/account.go +++ b/management/server/account/account.go @@ -3,11 +3,17 @@ package account type ExtraSettings struct { // PeerApprovalEnabled enables or disables the need for peers bo be approved by an administrator PeerApprovalEnabled bool + + // IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations + IntegratedValidatorGroups []string `gorm:"serializer:json"` } // Copy copies the ExtraSettings struct func (e *ExtraSettings) Copy() *ExtraSettings { + var cpGroup []string + return &ExtraSettings{ - PeerApprovalEnabled: e.PeerApprovalEnabled, + PeerApprovalEnabled: e.PeerApprovalEnabled, + IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 2b0c44196..a0eff239b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -12,20 +12,57 @@ import ( "time" "github.com/golang-jwt/jwt" - - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/route" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" ) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() { +} + func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { t.Helper() peer := &nbpeer.Peer{ @@ -367,7 +404,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { account.Groups[all.ID].Peers = append(account.Groups[all.ID].Peers, peer.ID) } - networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io") + validatedPeers := map[string]struct{}{} + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + networkMap := account.GetPeerNetworkMap(testCase.peerID, "netbird.io", validatedPeers) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -667,7 +709,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "get account by token failed") require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*Group{} + groupsByNames := map[string]*group.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -675,12 +717,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") }) } @@ -800,7 +842,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) } @@ -815,7 +857,7 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { t.Fatalf("expected to get an account for a user %s", userId) } - if account.Domain != domain { + if account != nil && account.Domain != domain { t.Errorf("updating domain. expected %s got %s", domain, account.Domain) } } @@ -835,13 +877,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { } if account == nil { t.Fatalf("expected to create an account for a user %s", userId) + return } - accountId := account.Id - - _, err = manager.GetAccountByUserOrAccountID("", accountId, "") + _, err = manager.GetAccountByUserOrAccountID("", account.Id, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) } _, err = manager.GetAccountByUserOrAccountID("", "", "") @@ -1124,7 +1165,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(peer1.ID) defer manager.peersUpdateManager.CloseChannel(peer1.ID) - group := Group{ + group := group.Group{ ID: "group-id", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1417,7 +1458,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[string]*route.Route{ "route-1": { ID: "route-1", @@ -1518,7 +1559,7 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*group.Group{ "group1": { ID: "group1", Peers: []string{"peer1"}, @@ -2112,8 +2153,8 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Settings: &Settings{GroupsPropagationEnabled: true}, Users: map[string]*User{ @@ -2160,10 +2201,10 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2196,10 +2237,10 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*Group{ - "group1": {ID: "group1", Name: "group1", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*group.Group{ + "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } @@ -2223,7 +2264,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}) } func createStore(t *testing.T) (Store, error) { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index aac35308c..18f942e68 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -8,6 +8,7 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -193,7 +194,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}) } func createDNSStore(t *testing.T) (Store, error) { @@ -278,13 +279,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &Group{ + newGroup1 := &group.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &Group{ + newGroup2 := &group.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 9d70a05d1..4fffa024d 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -165,7 +165,7 @@ func (e *EphemeralManager) cleanup() { log.Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(p.account.Id, id, activity.SystemInitiator) if err != nil { - log.Tracef("failed to delete ephemeral peer: %s", err) + log.Errorf("failed to delete ephemeral peer: %s", err) } } } diff --git a/management/server/file_store.go b/management/server/file_store.go index 0228285cb..2de852bee 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -10,6 +10,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" @@ -170,7 +171,7 @@ func restore(file string) (*FileStore, error) { // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI } } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index d8575a3bf..d53298d8f 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -188,7 +189,7 @@ func TestStore(t *testing.T) { Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } - account.Groups["all"] = &Group{ + account.Groups["all"] = &group.Group{ ID: "all", Name: "all", Peers: []string{"testpeer"}, @@ -320,7 +321,7 @@ func TestRestoreGroups_Migration(t *testing.T) { // create default group account := store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] - account.Groups = map[string]*Group{ + account.Groups = map[string]*group.Group{ "cfefqs706sqkneg59g3g": { ID: "cfefqs706sqkneg59g3g", Name: "All", @@ -336,7 +337,7 @@ func TestRestoreGroups_Migration(t *testing.T) { account = store.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] require.Contains(t, account.Groups, "cfefqs706sqkneg59g3g", "failed to restore a FileStore file - missing Account Groups") - require.Equal(t, GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") + require.Equal(t, group.GroupIssuedAPI, account.Groups["cfefqs706sqkneg59g3g"].Issued, "default group should has API issued mark") } func TestGetAccountByPrivateDomain(t *testing.T) { @@ -384,6 +385,7 @@ func TestFileStore_GetAccount(t *testing.T) { expected := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"] if expected == nil { t.Fatalf("expected account doesn't exist") + return } account, err := store.GetAccount(expected.Id) diff --git a/management/server/group.go b/management/server/group.go index 43d48e622..0fc952cdb 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -19,51 +20,8 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -const ( - GroupIssuedAPI = "api" - GroupIssuedJWT = "jwt" - GroupIssuedIntegration = "integration" -) - -// Group of the peers for ACL -type Group struct { - // ID of the group - ID string - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name visible in the UI - Name string - - // Issued defines how this group was created (enum of "api", "integration" or "jwt") - Issued string - - // Peers list of the group - Peers []string `gorm:"serializer:json"` - - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// EventMeta returns activity event meta related to the group -func (g *Group) EventMeta() map[string]any { - return map[string]any{"name": g.Name} -} - -func (g *Group) Copy() *Group { - group := &Group{ - ID: g.ID, - Name: g.Name, - Issued: g.Issued, - Peers: make([]string, len(g.Peers)), - IntegrationReference: g.IntegrationReference, - } - copy(group.Peers, g.Peers) - return group -} - // GetGroup object of the peers -func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { +func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -72,6 +30,15 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return nil, err } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + group, ok := account.Groups[groupID] if ok { return group, nil @@ -80,8 +47,8 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) } -// GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*Group, error) { +// GetAllGroups returns all groups in an account +func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -90,7 +57,34 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G return nil, err } - matchingGroups := make([]*Group, 0) + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + + groups := make([]*nbgroup.Group, 0, len(account.Groups)) + for _, item := range account.Groups { + groups = append(groups, item) + } + + return groups, nil +} + +// GetGroupByName filters all groups in an account by name and returns the one with the most peers +func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + matchingGroups := make([]*nbgroup.Group, 0) for _, group := range account.Groups { if group.Name == groupName { matchingGroups = append(matchingGroups, group) @@ -102,7 +96,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } maxPeers := -1 - var groupWithMostPeers *Group + var groupWithMostPeers *nbgroup.Group for i, group := range matchingGroups { if len(group.Peers) > maxPeers { maxPeers = len(group.Peers) @@ -114,7 +108,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*G } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *Group) error { +func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -123,11 +117,11 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G return err } - if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI { + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { existingGroup, err := account.FindGroupByName(newGroup.Name) if err != nil { @@ -234,7 +228,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // disable a deleting integration group if the initiator is not an admin service user - if g.Issued == GroupIssuedIntegration { + if g.Issued == nbgroup.GroupIssuedIntegration { executingUser := account.Users[userId] if executingUser == nil { return status.Errorf(status.NotFound, "user not found") @@ -304,6 +298,15 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } } + // check integrated peer validator groups + if account.Settings.Extra != nil { + for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups { + if groupID == integratedPeerValidatorGroups { + return &GroupLinkError{"integrated validator", g.Name} + } + } + } + delete(account.Groups, groupID) account.Network.IncSerial() @@ -319,7 +322,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) } // ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { +func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -328,7 +331,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) return nil, err } - groups := make([]*Group, 0, len(account.Groups)) + groups := make([]*nbgroup.Group, 0, len(account.Groups)) for _, item := range account.Groups { groups = append(groups, item) } diff --git a/management/server/group/group.go b/management/server/group/group.go new file mode 100644 index 000000000..79dfd995c --- /dev/null +++ b/management/server/group/group.go @@ -0,0 +1,46 @@ +package group + +import "github.com/netbirdio/netbird/management/server/integration_reference" + +const ( + GroupIssuedAPI = "api" + GroupIssuedJWT = "jwt" + GroupIssuedIntegration = "integration" +) + +// Group of the peers for ACL +type Group struct { + // ID of the group + ID string + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name visible in the UI + Name string + + // Issued defines how this group was created (enum of "api", "integration" or "jwt") + Issued string + + // Peers list of the group + Peers []string `gorm:"serializer:json"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// EventMeta returns activity event meta related to the group +func (g *Group) EventMeta() map[string]any { + return map[string]any{"name": g.Name} +} + +func (g *Group) Copy() *Group { + group := &Group{ + ID: g.ID, + Name: g.Name, + Issued: g.Issued, + Peers: make([]string, len(g.Peers)), + IntegrationReference: g.IntegrationReference, + } + copy(group.Peers, g.Peers) + return group +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 3a2195c88..35e9b2170 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -5,6 +5,7 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" ) @@ -24,22 +25,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = GroupIssuedIntegration + group.Issued = nbgroup.GroupIssuedIntegration err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = GroupIssuedJWT + group.Issued = nbgroup.GroupIssuedJWT err = am.SaveGroup(account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", GroupIssuedJWT) + t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = GroupIssuedAPI + group.Issued = nbgroup.GroupIssuedAPI group.ID = "" err = am.SaveGroup(account.Id, groupAdminUserID, group) if err == nil { @@ -129,51 +130,51 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &Group{ + groupForRoute := &nbgroup.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &Group{ + groupForNameServerGroups := &nbgroup.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &Group{ + groupForPolicies := &nbgroup.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &Group{ + groupForSetupKeys := &nbgroup.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &Group{ + groupForUsers := &nbgroup.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &Group{ + groupForIntegration := &nbgroup.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: GroupIssuedIntegration, + Issued: nbgroup.GroupIssuedIntegration, Peers: make([]string, 0), } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 341d202b6..4df24711e 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -343,10 +343,18 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p userID := "" // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, // or it uses a setup key to register. + if loginReq.GetJwtToken() != "" { - userID, err = s.validateToken(loginReq.GetJwtToken()) + for i := 0; i < 3; i++ { + userID, err = s.validateToken(loginReq.GetJwtToken()) + if err == nil { + break + } + log.Warnf("failed validating JWT token sent from peer %s with error %v. "+ + "Trying again as it may be due to the IdP cache issue", peerKey, err) + time.Sleep(200 * time.Millisecond) + } if err != nil { - log.Warnf("failed validating JWT token sent from peer %s", peerKey) return nil, err } } @@ -361,6 +369,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p Meta: extractPeerMeta(loginReq), UserID: userID, SetupKey: loginReq.GetSetupKey(), + ConnectionIP: realIP, }) if err != nil { diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 71088cfaf..d3c9954d3 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -76,6 +76,7 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings := &server.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), + RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, } if req.Settings.Extra != nil { @@ -143,6 +144,7 @@ func toAccountResponse(account *server.Account) *api.Account { JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, + RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, } if account.Settings.Extra != nil { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index fd2c4bfcd..9d174d0be 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -69,6 +69,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { Settings: &server.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, + RegularUsersViewBlocked: true, }, }, adminUser) @@ -96,6 +97,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, }, expectedArray: true, expectedID: accountID, @@ -114,6 +116,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr(""), JwtGroupsEnabled: br(false), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, }, expectedArray: false, expectedID: accountID, @@ -123,7 +126,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"]}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -132,6 +135,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr("roles"), JwtGroupsEnabled: br(true), JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, }, expectedArray: false, expectedID: accountID, @@ -141,7 +145,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 554400, @@ -150,6 +154,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { JwtGroupsClaimName: sr("groups"), JwtGroupsEnabled: br(true), JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 6c22e52bf..23b14c0a7 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -54,6 +54,10 @@ components: description: Period of time after which peer login expires (seconds). type: integer example: 43200 + regular_users_view_blocked: + description: Allows blocking regular users from viewing parts of the system. + type: boolean + example: true groups_propagation_enabled: description: Allows propagate the new user auto groups to peers that belongs to the user type: boolean @@ -77,6 +81,7 @@ components: required: - peer_login_expiration_enabled - peer_login_expiration + - regular_users_view_blocked AccountExtraSettings: type: object properties: @@ -144,6 +149,8 @@ components: description: How user was issued by API or Integration type: string example: api + permissions: + $ref: '#/components/schemas/UserPermissions' required: - id - email @@ -152,6 +159,14 @@ components: - auto_groups - status - is_blocked + UserPermissions: + type: object + properties: + dashboard_view: + description: User's permission to view the dashboard + type: string + enum: [ "limited", "blocked", "full" ] + example: limited UserRequest: type: object properties: @@ -340,6 +355,7 @@ components: - user_id - version - ui_version + - approval_required AccessiblePeer: allOf: - $ref: '#/components/schemas/PeerMinimum' @@ -589,8 +605,6 @@ components: type: string enum: ["api", "integration", "jwt"] example: api - type: string - example: api required: - id - name diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index d10b28c2b..a141603ff 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -69,6 +69,20 @@ const ( GeoLocationCheckActionDeny GeoLocationCheckAction = "deny" ) +// Defines values for GroupIssued. +const ( + GroupIssuedApi GroupIssued = "api" + GroupIssuedIntegration GroupIssued = "integration" + GroupIssuedJwt GroupIssued = "jwt" +) + +// Defines values for GroupMinimumIssued. +const ( + GroupMinimumIssuedApi GroupMinimumIssued = "api" + GroupMinimumIssuedIntegration GroupMinimumIssued = "integration" + GroupMinimumIssuedJwt GroupMinimumIssued = "jwt" +) + // Defines values for NameserverNsType. const ( NameserverNsTypeUdp NameserverNsType = "udp" @@ -129,6 +143,13 @@ const ( UserStatusInvited UserStatus = "invited" ) +// Defines values for UserPermissionsDashboardView. +const ( + UserPermissionsDashboardViewBlocked UserPermissionsDashboardView = "blocked" + UserPermissionsDashboardViewFull UserPermissionsDashboardView = "full" + UserPermissionsDashboardViewLimited UserPermissionsDashboardView = "limited" +) + // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud @@ -186,6 +207,9 @@ type AccountSettings struct { // PeerLoginExpirationEnabled Enables or disables peer login expiration globally. After peer's login has expired the user has to log in (authenticate). Applies only to peers that were added by a user (interactive SSO login). PeerLoginExpirationEnabled bool `json:"peer_login_expiration_enabled"` + + // RegularUsersViewBlocked Allows blocking regular users from viewing parts of the system. + RegularUsersViewBlocked bool `json:"regular_users_view_blocked"` } // Checks List of objects that perform the actual checks @@ -286,8 +310,8 @@ type Group struct { // Id Group ID Id string `json:"id"` - // Issued How group was issued by API or from JWT token - Issued *string `json:"issued,omitempty"` + // Issued How the group was issued (api, integration, jwt) + Issued *GroupIssued `json:"issued,omitempty"` // Name Group Name identifier Name string `json:"name"` @@ -299,13 +323,16 @@ type Group struct { PeersCount int `json:"peers_count"` } +// GroupIssued How the group was issued (api, integration, jwt) +type GroupIssued string + // GroupMinimum defines model for GroupMinimum. type GroupMinimum struct { // Id Group ID Id string `json:"id"` - // Issued How group was issued by API or from JWT token - Issued *string `json:"issued,omitempty"` + // Issued How the group was issued (api, integration, jwt) + Issued *GroupMinimumIssued `json:"issued,omitempty"` // Name Group Name identifier Name string `json:"name"` @@ -314,6 +341,9 @@ type GroupMinimum struct { PeersCount int `json:"peers_count"` } +// GroupMinimumIssued How the group was issued (api, integration, jwt) +type GroupMinimumIssued string + // GroupRequest defines model for GroupRequest. type GroupRequest struct { // Name Group name identifier @@ -443,7 +473,7 @@ type Peer struct { AccessiblePeers []AccessiblePeer `json:"accessible_peers"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -512,7 +542,7 @@ type Peer struct { // PeerBase defines model for PeerBase. type PeerBase struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -584,7 +614,7 @@ type PeerBatch struct { AccessiblePeersCount int `json:"accessible_peers_count"` // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired *bool `json:"approval_required,omitempty"` + ApprovalRequired bool `json:"approval_required"` // CityName Commonly used English name of the city CityName CityName `json:"city_name"` @@ -1089,7 +1119,8 @@ type User struct { LastLogin *time.Time `json:"last_login,omitempty"` // Name User's name from idp provider - Name string `json:"name"` + Name string `json:"name"` + Permissions *UserPermissions `json:"permissions,omitempty"` // Role User's NetBird account role Role string `json:"role"` @@ -1119,6 +1150,15 @@ type UserCreateRequest struct { Role string `json:"role"` } +// UserPermissions defines model for UserPermissions. +type UserPermissions struct { + // DashboardView User's permission to view the dashboard + DashboardView *UserPermissionsDashboardView `json:"dashboard_view,omitempty"` +} + +// UserPermissionsDashboardView User's permission to view the dashboard +type UserPermissionsDashboardView string + // UserRequest defines model for UserRequest. type UserRequest struct { // AutoGroups Group IDs to auto-assign to peers registered by this user diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index b37f4fd2f..47bcf2f32 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -4,15 +4,15 @@ import ( "encoding/json" "net/http" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/util" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" ) // GroupsHandler is a handler that returns groups of the account @@ -35,19 +35,25 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - var groups []*api.Group - for _, g := range account.Groups { - groups = append(groups, toGroupResponse(account, g)) + groups, err := h.accountManager.GetAllGroups(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return } - util.WriteJSONObject(w, groups) + groupsResponse := make([]*api.Group, 0, len(groups)) + for _, group := range groups { + groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + } + + util.WriteJSONObject(w, groupsResponse) } // UpdateGroup handles update to a group identified by a given ID @@ -104,7 +110,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ ID: groupID, Name: req.Name, Peers: peers, @@ -148,10 +154,10 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := server.Group{ + group := nbgroup.Group{ Name: req.Name, Peers: peers, - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, } err = h.accountManager.SaveGroup(account.Id, user.Id, &group) @@ -207,7 +213,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(claims) + account, user, err := h.accountManager.GetAccountFromToken(claims) if err != nil { util.WriteError(err, w) return @@ -221,7 +227,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := h.accountManager.GetGroup(account.Id, groupID) + group, err := h.accountManager.GetGroup(account.Id, groupID, user.Id) if err != nil { util.WriteError(err, w) return @@ -234,12 +240,12 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { } } -func toGroupResponse(account *server.Account, group *server.Group) *api.Group { +func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, Name: group.Name, - Issued: &group.Issued, + Issued: (*api.GroupIssued)(&group.Issued), } for _, pid := range group.Peers { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index 5b47b1208..3d74b848c 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -15,6 +15,7 @@ import ( "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -28,30 +29,30 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandler { +func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(accountID, userID string, group *server.Group) error { + SaveGroupFunc: func(accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_, groupID string) (*server.Group, error) { + GetGroupFunc: func(_, groupID, _ string) (*nbgroup.Group, error) { if groupID != "idofthegroup" { return nil, status.Errorf(status.NotFound, "not found") } if groupID == "id-jwt-group" { - return &server.Group{ + return &nbgroup.Group{ ID: "id-jwt-group", Name: "Default Group", - Issued: server.GroupIssuedJWT, + Issued: nbgroup.GroupIssuedJWT, }, nil } - return &server.Group{ + return &nbgroup.Group{ ID: "idofthegroup", Name: "Group", - Issued: server.GroupIssuedAPI, + Issued: nbgroup.GroupIssuedAPI, }, nil }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { @@ -62,10 +63,10 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle Users: map[string]*server.User{ user.Id: user, }, - Groups: map[string]*server.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: server.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: server.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: server.GroupIssuedAPI}, + Groups: map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, }, }, user, nil }, @@ -118,7 +119,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &server.Group{ + group := &nbgroup.Group{ ID: "idofthegroup", Name: "Group", } @@ -153,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &server.Group{} + got := &nbgroup.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } @@ -187,7 +188,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-was-set", Name: "Default POSTed Group", - Issued: &groupIssuedAPI, + Issued: (*api.GroupIssued)(&groupIssuedAPI), }, }, { @@ -209,7 +210,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-existed", Name: "Default POSTed Group", - Issued: &groupIssuedAPI, + Issued: (*api.GroupIssued)(&groupIssuedAPI), }, }, { @@ -240,7 +241,7 @@ func TestWriteGroup(t *testing.T) { expectedGroup: &api.Group{ Id: "id-jwt-group", Name: "changed", - Issued: &groupIssuedJWT, + Issued: (*api.GroupIssued)(&groupIssuedJWT), }, }, } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index d035ae0b7..bdbeba346 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,7 +9,6 @@ import ( "github.com/rs/cors" "github.com/netbirdio/management-integrations/integrations" - s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/middleware" diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index d4d2558e8..77b4578f8 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -6,8 +6,10 @@ import ( "net/http" "github.com/gorilla/mux" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -61,10 +63,18 @@ func (h *PeersHandler) getPeer(account *server.Account, peerID, userID string, w groupsInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + util.WriteJSONObject(w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { @@ -75,11 +85,18 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe return } - update := &nbpeer.Peer{ID: peerID, SSHEnabled: req.SshEnabled, Name: req.Name, - LoginExpirationEnabled: req.LoginExpirationEnabled} + update := &nbpeer.Peer{ + ID: peerID, + SSHEnabled: req.SshEnabled, + Name: req.Name, + LoginExpirationEnabled: req.LoginExpirationEnabled, + } if req.ApprovalRequired != nil { - update.Status = &nbpeer.PeerStatus{RequiresApproval: *req.ApprovalRequired} + // todo: looks like that we reset all status property, is it right? + update.Status = &nbpeer.PeerStatus{ + RequiresApproval: *req.ApprovalRequired, + } } peer, err := h.accountManager.UpdatePeer(account.Id, user.Id, update) @@ -91,15 +108,24 @@ func (h *PeersHandler) updatePeer(account *server.Account, user *server.User, pe groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validPeers) accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers)) + _, valid := validPeers[peer.ID] + + util.WriteJSONObject(w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) } func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(accountID, peerID, userID) if err != nil { + log.Errorf("failed to delete peer: %v", err) util.WriteError(err, w) return } @@ -138,46 +164,68 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(claims) - if err != nil { - util.WriteError(err, w) - return - } - - peers, err := h.accountManager.GetPeers(account.Id, user.Id) - if err != nil { - util.WriteError(err, w) - return - } - - dnsDomain := h.accountManager.GetDNSDomain() - - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { - peerToReturn, err := h.checkPeerStatus(peer) - if err != nil { - util.WriteError(err, w) - return - } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) - - accessiblePeerNumbers := h.accessiblePeersNumber(account, peer.ID) - - respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) - } - util.WriteJSONObject(w, respBody) - return - default: + if r.Method != http.MethodGet { util.WriteError(status.Errorf(status.NotFound, "unknown METHOD"), w) + return } + + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + peers, err := h.accountManager.GetPeers(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return + } + + dnsDomain := h.accountManager.GetDNSDomain() + + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { + peerToReturn, err := h.checkPeerStatus(peer) + if err != nil { + util.WriteError(err, w) + return + } + groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + accessiblePeerNumbers, _ := h.accessiblePeersNumber(account, peer.ID) + + respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers)) + } + + validPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed to list appreoved peers: %v", err) + util.WriteError(fmt.Errorf("internal error"), w) + return + } + h.setApprovalRequiredFlag(respBody, validPeersMap) + + util.WriteJSONObject(w, respBody) } -func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) int { - netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain()) - return len(netMap.Peers) + len(netMap.OfflinePeers) +func (h *PeersHandler) accessiblePeersNumber(account *server.Account, peerID string) (int, error) { + validatedPeersMap, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + return 0, err + } + + netMap := account.GetPeerNetworkMap(peerID, h.accountManager.GetDNSDomain(), validatedPeersMap) + return len(netMap.Peers) + len(netMap.OfflinePeers), nil +} + +func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { + for _, peer := range respBody { + _, ok := approvedPeersMap[peer.Id] + if !ok { + peer.ApprovalRequired = true + } + } } func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { @@ -206,7 +254,7 @@ func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.Access return accessiblePeers } -func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMinimum { +func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { var groupsInfo []api.GroupMinimum groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -230,7 +278,7 @@ func toGroupsInfo(groups map[string]*server.Group, peerID string) []api.GroupMin return groupsInfo } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core @@ -257,7 +305,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeers: accessiblePeer, - ApprovalRequired: &peer.Status.RequiresApproval, + ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } @@ -290,7 +338,6 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, AccessiblePeersCount: accessiblePeersCount, - ApprovalRequired: &peer.Status.RequiresApproval, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, } diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index e6b858036..74e682854 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" @@ -51,7 +52,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Policies: []*server.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index 7b68479ed..ebbd5954f 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -13,13 +13,12 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/status" - - "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" ) const ( @@ -44,7 +43,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup SetupKeys: map[string]*server.SetupKey{ defaultKey.Key: defaultKey, }, - Groups: map[string]*server.Group{ + Groups: map[string]*nbgroup.Group{ "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}, }, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 5d92b65e5..ed8a3f543 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -288,5 +288,8 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { IsBlocked: user.IsBlocked, LastLogin: &user.LastLogin, Issued: &user.Issued, + Permissions: &api.UserPermissions{ + DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView), + }, } } diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index ff886ca9f..91f19d8d8 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -105,7 +105,7 @@ func initUsersTestData() *UsersHandler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil) + info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 4e2c3d0b3..2bb279c76 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -99,6 +99,8 @@ func WriteError(err error, w http.ResponseWriter) { httpStatus = http.StatusUnprocessableEntity case status.Unauthorized: httpStatus = http.StatusUnauthorized + case status.BadRequest: + httpStatus = http.StatusBadRequest default: } msg = strings.ToLower(err.Error()) diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index 706e4d330..2f21b3b54 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -115,7 +115,15 @@ func (ac *AzureCredentials) requestJWTToken() (*http.Response, error) { data.Set("client_id", ac.clientConfig.ClientID) data.Set("client_secret", ac.clientConfig.ClientSecret) data.Set("grant_type", ac.clientConfig.GrantType) - data.Set("scope", "https://graph.microsoft.com/.default") + parsedURL, err := url.Parse(ac.clientConfig.GraphAPIEndpoint) + if err != nil { + return nil, err + } + + // get base url and add "/.default" as scope + baseURL := parsedURL.Scheme + "://" + parsedURL.Host + scopeURL := baseURL + "/.default" + data.Set("scope", scopeURL) payload := strings.NewReader(data.Encode()) req, err := http.NewRequest(http.MethodPost, ac.clientConfig.TokenEndpoint, payload) diff --git a/management/server/idp/okta.go b/management/server/idp/okta.go index d20ee7e48..c8d33a207 100644 --- a/management/server/idp/okta.go +++ b/management/server/idp/okta.go @@ -273,7 +273,7 @@ func (om *OktaManager) DeleteUser(userID string) error { return nil } -// parseOktaUserToUserData parse okta user to UserData. +// parseOktaUser parse okta user to UserData. func parseOktaUser(user *okta.User) (*UserData, error) { var oktaUser struct { Email string `json:"email"` diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go new file mode 100644 index 000000000..cd770a801 --- /dev/null +++ b/management/server/integrated_validator.go @@ -0,0 +1,80 @@ +package server + +import ( + "errors" + + "github.com/google/martian/v3/log" + + "github.com/netbirdio/netbird/management/server/account" +) + +// UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. +// It retrieves the account associated with the provided userID, then updates the integrated validator groups +// with the provided list of group ids. The updated account is then saved. +// +// Parameters: +// - accountID: The ID of the account for which integrated validator groups are to be updated. +// - userID: The ID of the user whose account is being updated. +// - groups: A slice of strings representing the ids of integrated validator groups to be updated. +// +// Returns: +// - error: An error if any occurred during the process, otherwise returns nil +func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + ok, err := am.GroupValidation(accountID, groups) + if err != nil { + log.Debugf("error validating groups: %s", err.Error()) + return err + } + + if !ok { + log.Debugf("invalid groups") + return errors.New("invalid groups") + } + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + a, err := am.Store.GetAccountByUser(userID) + if err != nil { + return err + } + + var extra *account.ExtraSettings + + if a.Settings.Extra != nil { + extra = a.Settings.Extra + } else { + extra = &account.ExtraSettings{} + a.Settings.Extra = extra + } + extra.IntegratedValidatorGroups = groups + return am.Store.SaveAccount(a) +} + +func (am *DefaultAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if len(groups) == 0 { + return true, nil + } + accountsGroups, err := am.ListGroups(accountId) + if err != nil { + return false, err + } + for _, group := range groups { + var found bool + for _, accountGroup := range accountsGroups { + if accountGroup.ID == group { + found = true + break + } + } + if !found { + return false, nil + } + } + + return true, nil +} + +func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { + return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +} diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go new file mode 100644 index 000000000..e87755b87 --- /dev/null +++ b/management/server/integrated_validator/interface.go @@ -0,0 +1,19 @@ +package integrated_validator + +import ( + "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" +) + +// IntegratedValidator interface exists to avoid the circle dependencies +type IntegratedValidator interface { + ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) + PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer + IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) + GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + PeerDeleted(accountID, peerID string) error + SetPeerInvalidationListener(fn func(accountID string)) + Stop() +} diff --git a/management/server/integration_reference/integration_reference.go b/management/server/integration_reference/integration_reference.go new file mode 100644 index 000000000..254b4e62f --- /dev/null +++ b/management/server/integration_reference/integration_reference.go @@ -0,0 +1,23 @@ +package integration_reference + +import ( + "fmt" + "strings" +) + +// IntegrationReference holds the reference to a particular integration +type IntegrationReference struct { + ID int + IntegrationType string +} + +func (ir IntegrationReference) String() string { + return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) +} + +func (ir IntegrationReference) CacheKey(path ...string) string { + if len(path) == 0 { + return ir.String() + } + return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) +} diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 6ea902003..98ad0de0c 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -9,8 +9,6 @@ import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/activity" - "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -19,6 +17,7 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/util" ) @@ -413,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error) peersUpdateManager := NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { return nil, "", err } diff --git a/management/server/management_test.go b/management/server/management_test.go index fb3f74cb9..13db5ae95 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -10,24 +10,22 @@ import ( sync2 "sync" "time" - "github.com/netbirdio/netbird/management/server/activity" - - "google.golang.org/grpc/credentials/insecure" - - "github.com/netbirdio/netbird/management/server" - pb "github.com/golang/protobuf/proto" //nolint - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/encryption" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" + "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -448,6 +446,43 @@ var _ = Describe("Management service", func() { }) }) +type MocIntegratedValidator struct { +} + +func (a MocIntegratedValidator) ValidateExtraSettings(newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, error) { + return update, nil +} + +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for p := range peers { + validatedPeers[p] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool) { + return false, false +} + +func (MocIntegratedValidator) PeerDeleted(_, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + +} + +func (MocIntegratedValidator) Stop() {} + func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -504,7 +539,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { peersUpdateManager := server.NewPeersUpdateManager(nil) eventStore := &activity.InMemoryEventStore{} accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", - eventStore, nil, false) + eventStore, nil, false, MocIntegratedValidator{}) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 7be3c818d..c479867d2 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" @@ -32,7 +33,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, @@ -117,7 +118,7 @@ func (mockDatasource) GetAllAccounts() []*server.Account { UsedTimes: 1, }, }, - Groups: map[string]*server.Group{ + Groups: map[string]*group.Group{ "1": {}, "2": {}, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index f518372ed..8687937dc 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -31,11 +32,12 @@ type MockAccountManager struct { GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, error) - GetGroupFunc func(accountID, groupID string) (*server.Group, error) - GetGroupByNameFunc func(accountID, groupName string) (*server.Group, error) - SaveGroupFunc func(accountID, userID string, group *server.Group) error + GetGroupFunc func(accountID, groupID, userID string) (*group.Group, error) + GetAllGroupsFunc func(accountID, userID string) ([]*group.Group, error) + GetGroupByNameFunc func(accountID, groupName string) (*group.Group, error) + SaveGroupFunc func(accountID, userID string, group *group.Group) error DeleteGroupFunc func(accountID, userId, groupID string) error - ListGroupsFunc func(accountID string) ([]*server.Group, error) + ListGroupsFunc func(accountID string) ([]*group.Group, error) GroupAddPeerFunc func(accountID, groupID, peerID string) error GroupDeletePeerFunc func(accountID, groupID, peerID string) error DeleteRuleFunc func(accountID, ruleID, userID string) error @@ -90,6 +92,32 @@ type MockAccountManager struct { DeletePostureChecksFunc func(accountID, postureChecksID, userID string) error ListPostureChecksFunc func(accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(accountID string, userID string, groups []string) error + GroupValidationFunc func(accountId string, groups []string) (bool, error) +} + +func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { + approvedPeers := make(map[string]struct{}) + for id := range account.Peers { + approvedPeers[id] = struct{}{} + } + return approvedPeers, nil +} + +// GetGroup mock implementation of GetGroup from server.AccountManager interface +func (am *MockAccountManager) GetGroup(accountId, groupID, userID string) (*group.Group, error) { + if am.GetGroupFunc != nil { + return am.GetGroupFunc(accountId, groupID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") +} + +// GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface +func (am *MockAccountManager) GetAllGroups(accountID, userID string) ([]*group.Group, error) { + if am.GetAllGroupsFunc != nil { + return am.GetAllGroupsFunc(accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllGroups is not implemented") } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -243,16 +271,8 @@ func (am *MockAccountManager) AddPeer( return nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented") } -// GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(accountID, groupID string) (*server.Group, error) { - if am.GetGroupFunc != nil { - return am.GetGroupFunc(accountID, groupID) - } - return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented") -} - // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*server.Group, error) { +func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*group.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(accountID, groupName) } @@ -260,7 +280,7 @@ func (am *MockAccountManager) GetGroupByName(accountID, groupName string) (*serv } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.Group) error { +func (am *MockAccountManager) SaveGroup(accountID, userID string, group *group.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(accountID, userID, group) } @@ -276,7 +296,7 @@ func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) err } // ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, error) { +func (am *MockAccountManager) ListGroups(accountID string) ([]*group.Group, error) { if am.ListGroupsFunc != nil { return am.ListGroupsFunc(accountID) } @@ -685,3 +705,19 @@ func (am *MockAccountManager) GetIdpManager() idp.Manager { } return nil } + +// UpdateIntegratedValidatorGroups mocks UpdateIntegratedApprovalGroups of the AccountManager interface +func (am *MockAccountManager) UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error { + if am.UpdateIntegratedValidatorGroupsFunc != nil { + return am.UpdateIntegratedValidatorGroupsFunc(accountID, userID, groups) + } + return status.Errorf(codes.Unimplemented, "method UpdateIntegratedValidatorGroups is not implemented") +} + +// GroupValidation mocks GroupValidation of the AccountManager interface +func (am *MockAccountManager) GroupValidation(accountId string, groups []string) (bool, error) { + if am.GroupValidationFunc != nil { + return am.GroupValidationFunc(accountId, groups) + } + return false, status.Errorf(codes.Unimplemented, "method GroupValidation is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index e521805c8..fa7793602 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -10,6 +10,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -261,7 +262,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*Group) error { +func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d04ac1a20..b10f9387a 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -8,6 +8,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -759,7 +760,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createNSStore(t *testing.T) (Store, error) { @@ -831,12 +832,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &Group{ + newGroup1 := &nbgroup.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &Group{ + newGroup2 := &nbgroup.Group{ ID: group2ID, Name: group2ID, } diff --git a/management/server/peer.go b/management/server/peer.go index 53b86e9b3..fda8e49e9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -7,16 +7,12 @@ import ( "time" "github.com/rs/xid" - - "github.com/netbirdio/management-integrations/additions" - - "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) // PeerSync used as a data object between the gRPC API and AccountManager on Sync request. @@ -37,6 +33,8 @@ type PeerLogin struct { UserID string // SetupKey references to a server.SetupKey to log in. Can be empty when UserID is used or auth is not required. SetupKey string + // ConnectionIP is the real IP of the peer + ConnectionIP net.IP } // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if @@ -52,8 +50,17 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } peers := make([]*nbpeer.Peer, 0) peersMap := make(map[string]*nbpeer.Peer) + + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return peers, nil + } + for _, peer := range account.Peers { if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin @@ -66,7 +73,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(peer.ID) + aclPeers, _ := account.getPeerConnectionResources(peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -162,7 +169,7 @@ func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nb return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) } - update, err = additions.ValidatePeersUpdateRequest(update, peer, userID, accountID, am.eventStore, am.GetDNSDomain()) + update, err = am.integratedPeerValidator.ValidatePeer(update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { return nil, err } @@ -239,6 +246,12 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // the 2nd loop performs the actual modification for _, peer := range peers { + + err := am.integratedPeerValidator.PeerDeleted(account.Id, peer.ID) + if err != nil { + return err + } + account.DeletePeer(peer.ID) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{ @@ -299,7 +312,17 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro if peer == nil { return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) } - return account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + groups := make(map[string][]string) + for groupID, group := range account.Groups { + groups[groupID] = group.Peers + } + + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + if err != nil { + return nil, err + } + return account.GetPeerNetworkMap(peer.ID, am.dnsDomain, validatedPeers), nil } // GetPeerNetwork returns the Network for a given peer @@ -428,10 +451,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P CreatedAt: registrationTime, LoginExpirationEnabled: addedByUser, Ephemeral: ephemeral, - } - - if account.Settings.Extra != nil { - newPeer = additions.PreparePeer(newPeer, account.Settings.Extra) + Location: peer.Location, } // add peer to 'All' group @@ -462,6 +482,8 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } } + newPeer = am.integratedPeerValidator.PreparePeer(account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) + if addedByUser { user, err := account.FindUser(userID) if err != nil { @@ -487,7 +509,11 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P am.updateAccountPeers(account) - networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain) + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + networkMap := account.GetPeerNetworkMap(newPeer.ID, am.dnsDomain, approvedPeersMap) return newPeer, networkMap, nil } @@ -524,23 +550,53 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *Network if peerLoginExpired(peer, account) { return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + requiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + if requiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + if isStatusChanged { + am.updateAccountPeers(account) + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) { account, err := am.Store.GetAccountByPeerPubKey(login.WireGuardPubKey) - if err != nil { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. - return am.AddPeer(login.SetupKey, login.UserID, &nbpeer.Peer{ + newPeer := &nbpeer.Peer{ Key: login.WireGuardPubKey, Meta: login.Meta, SSHKey: login.SSHKey, - }) + } + if am.geo != nil && login.ConnectionIP != nil { + location, err := am.geo.Lookup(login.ConnectionIP) + if err != nil { + log.Warnf("failed to get location for new peer realip: [%s]: %v", login.ConnectionIP.String(), err) + } else { + newPeer.Location.ConnectionIP = login.ConnectionIP + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + + } + } + + return am.AddPeer(login.SetupKey, login.UserID, newPeer) } log.Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) return nil, nil, status.Errorf(status.Internal, "failed while logging in peer") @@ -590,6 +646,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw am.StoreEvent(login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) } + isRequiresApproval, isStatusChanged := am.integratedPeerValidator.IsNotValidPeer(account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) peer, updated := updatePeerMeta(peer, login.Meta, account) if updated { shouldStoreAccount = true @@ -607,10 +664,23 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } } - if updateRemotePeers { + if updateRemotePeers || isStatusChanged { am.updateAccountPeers(account) } - return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil + + if isRequiresApproval { + emptyMap := &NetworkMap{ + Network: account.Network.Copy(), + } + return peer, emptyMap, nil + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, nil, err + } + + return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap), nil } func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { @@ -738,6 +808,10 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) + } + peer := account.GetPeer(peerID) if peer == nil { return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) @@ -755,8 +829,13 @@ func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbp return nil, err } + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + return nil, err + } + for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(p.ID) + aclPeers, _ := account.getPeerConnectionResources(p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -780,8 +859,13 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco func (am *DefaultAccountManager) updateAccountPeers(account *Account) { peers := account.GetPeers() + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + log.Errorf("failed send out updates to peers, failed to validate peer: %v", err) + return + } for _, peer := range peers { - remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain) + remotePeerNetworkMap := account.GetPeerNetworkMap(peer.ID, am.dnsDomain, approvedPeersMap) update := toSyncResponse(nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain()) am.peersUpdateManager.SendUpdate(peer.ID, &UpdateMessage{Update: update}) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index ee84ea47d..6063cc2a7 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -4,11 +4,11 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/rs/xid" + "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" ) @@ -200,8 +200,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 Group - group2 Group + group1 nbgroup.Group + group2 nbgroup.Group policy Policy ) @@ -392,6 +392,8 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { Id: someUser, Role: UserRoleUser, } + account.Settings.RegularUsersViewBlocked = false + err = manager.Store.SaveAccount(account) if err != nil { t.Fatal(err) @@ -480,3 +482,153 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } assert.NotNil(t, peer) } + +func TestDefaultAccountManager_GetPeers(t *testing.T) { + testCases := []struct { + name string + role UserRole + limitedViewSettings bool + isServiceUser bool + expectedPeerCount int + }{ + { + name: "Regular user, no limited view settings, not a service user", + role: UserRoleUser, + limitedViewSettings: false, + isServiceUser: false, + expectedPeerCount: 1, + }, + { + name: "Service user, no limited view settings", + role: UserRoleUser, + limitedViewSettings: false, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Regular user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 0, + }, + { + name: "Service user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Admin, no limited view settings, not a service user", + role: UserRoleAdmin, + limitedViewSettings: false, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Admin service user, no limited view settings", + role: UserRoleAdmin, + limitedViewSettings: false, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Admin, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Admin Service user, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + isServiceUser: true, + expectedPeerCount: 2, + }, + { + name: "Owner, no limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + { + name: "Owner, limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + isServiceUser: false, + expectedPeerCount: 2, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + // account with an admin and a regular user + accountID := "test_account" + adminUser := "account_creator" + someUser := "some_user" + account := newAccountWithId(accountID, adminUser, "") + account.Users[someUser] = &User{ + Id: someUser, + Role: testCase.role, + IsServiceUser: testCase.isServiceUser, + } + account.Policies = []*Policy{} + account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + + err = manager.Store.SaveAccount(account) + if err != nil { + t.Fatal(err) + return + } + + peerKey1, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + peerKey2, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + _, _, err = manager.AddPeer("", someUser, &nbpeer.Peer{ + Key: peerKey1.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"}, + }) + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + _, _, err = manager.AddPeer("", adminUser, &nbpeer.Peer{ + Key: peerKey2.PublicKey().String(), + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"}, + }) + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + peers, err := manager.GetPeers(accountID, someUser) + if err != nil { + t.Fatal(err) + return + } + assert.NotNil(t, peers) + + assert.Len(t, peers, testCase.expectedPeerCount) + + }) + } + +} diff --git a/management/server/policy.go b/management/server/policy.go index 8265dabb5..e162d2b3b 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,11 +5,11 @@ import ( "strconv" "strings" - "github.com/netbirdio/management-integrations/additions" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -211,7 +211,8 @@ type FirewallRule struct { // getPeerConnectionResources for a given peer // // This function returns the list of peers and firewall rules that are applicable to a given peer. -func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []*FirewallRule) { +func (a *Account) getPeerConnectionResources(peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { + generateResources, getAccumulatedResources := a.connResourcesGenerator() for _, policy := range a.Policies { if !policy.Enabled { @@ -223,10 +224,8 @@ func (a *Account) getPeerConnectionResources(peerID string) ([]*nbpeer.Peer, []* continue } - sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks) - destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil) - sourcePeers = additions.ValidatePeers(sourcePeers) - destinationPeers = additions.ValidatePeers(destinationPeers) + sourcePeers, peerInSources := getAllPeersFromGroups(a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + destinationPeers, peerInDestinations := getAllPeersFromGroups(a, rule.Destinations, peerID, nil, validatedPeersMap) if rule.Bidirectional { if peerInSources { @@ -264,7 +263,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in all, err := a.GetGroupAll() if err != nil { log.Errorf("failed to get group all: %v", err) - all = &Group{} + all = &nbgroup.Group{} } return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { @@ -491,7 +490,7 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { // // Important: Posture checks are applicable only to source group peers, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { +func getAllPeersFromGroups(account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) for _, g := range groups { @@ -512,6 +511,10 @@ func getAllPeersFromGroups(account *Account, groups []string, peerID string, sou continue } + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + if peer.ID == peerID { peerInGroups = true continue diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 681bab1da..1ea3bb379 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" ) @@ -56,7 +57,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -135,16 +136,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, } + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(p.ID) + peers, firewallRules := account.getPeerConnectionResources(p.ID, validatedPeers) assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -299,7 +305,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -374,8 +380,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, } + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } + t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -403,7 +414,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -433,7 +444,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) epectedFirewallRules := []*FirewallRule{ @@ -454,7 +465,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources("peerC") + peers, firewallRules := account.getPeerConnectionResources("peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) epectedFirewallRules := []*FirewallRule{ @@ -569,7 +580,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*Group{ + Groups: map[string]*nbgroup.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -644,10 +655,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }) + approvedPeers := make(map[string]struct{}) + for p := range account.Peers { + approvedPeers[p] = struct{}{} + } t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -657,7 +672,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) expectedFirewallRules := []*FirewallRule{ @@ -673,7 +688,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -683,7 +698,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -698,19 +713,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources("peerB") + peers, firewallRules := account.getPeerConnectionResources("peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources("peerI") + peers, firewallRules = account.getPeerConnectionResources("peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources("peerC") + peers, firewallRules = account.getPeerConnectionResources("peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -725,14 +740,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources("peerE") + peers, firewallRules = account.getPeerConnectionResources("peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources("peerA") + peers, firewallRules = account.getPeerConnectionResources("peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5a56eaa8b..9f8ea08c9 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" ) @@ -858,7 +859,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { groups, err := am.ListGroups(account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *Group + var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -967,7 +968,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &Group{ + newGroup := &nbgroup.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1014,7 +1015,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return nil, err } eventStore := &activity.InMemoryEventStore{} - return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false) + return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}) } func createRouterStore(t *testing.T) (Store, error) { @@ -1195,7 +1196,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*Group{ + newGroup := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 972665527..ff6fb3204 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -339,6 +339,10 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + } + keys := make([]*SetupKey, 0, len(account.SetupKeys)) for _, key := range account.SetupKeys { var k *SetupKey @@ -368,6 +372,10 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (* return nil, err } + if !user.HasAdminPower() && !user.IsServiceUser { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") + } + var foundKey *SetupKey for _, key := range account.SetupKeys { if key.Id == keyID { diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index c22df2094..43edabbd6 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -24,7 +25,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -82,7 +83,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -91,7 +92,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(account.Id, userID, &Group{ + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -166,6 +167,37 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } +func TestGetSetupKeys(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(userID, "") + if err != nil { + t.Fatal(err) + } + + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + ID: "group_1", + Name: "group_name_1", + Peers: []string{}, + }) + if err != nil { + t.Fatal(err) + } + + err = manager.SaveGroup(account.Id, userID, &nbgroup.Group{ + ID: "group_2", + Name: "group_name_2", + Peers: []string{}, + }) + if err != nil { + t.Fatal(err) + } +} + func TestGenerateDefaultSetupKey(t *testing.T) { expectedName := "Default key" expectedRevoke := false diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index f6a6f92a7..e6a9c8467 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -17,6 +17,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" + nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" @@ -64,7 +65,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore, sql.SetMaxOpenConns(conns) // TODO: make it configurable err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, + &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, ) @@ -99,17 +100,17 @@ func NewSqliteStoreFromFileStore(filestore *FileStore, dataDir string, metrics t // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { - log.Debugf("acquiring global lock") + log.Tracef("acquiring global lock") start := time.Now() s.globalAccountLock.Lock() unlock = func() { s.globalAccountLock.Unlock() - log.Debugf("released global lock in %v", time.Since(start)) + log.Tracef("released global lock in %v", time.Since(start)) } took := time.Since(start) - log.Debugf("took %v to acquire global lock", took) + log.Tracef("took %v to acquire global lock", took) if s.metrics != nil { s.metrics.StoreMetrics().CountGlobalLockAcquisitionDuration(took) } @@ -118,7 +119,7 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { } func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { - log.Debugf("acquiring lock for account %s", accountID) + log.Tracef("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) @@ -127,7 +128,7 @@ func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { unlock = func() { mtx.Unlock() - log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + log.Tracef("released lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -434,7 +435,7 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { } account.UsersG = nil - account.Groups = make(map[string]*Group, len(account.GroupsG)) + account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } diff --git a/management/server/user.go b/management/server/user.go index f1516139b..b955c4058 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -10,6 +10,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -49,23 +50,6 @@ type UserStatus string // UserRole is the role of a User type UserRole string -// IntegrationReference holds the reference to a particular integration -type IntegrationReference struct { - ID int - IntegrationType string -} - -func (ir IntegrationReference) String() string { - return fmt.Sprintf("%s:%d", ir.IntegrationType, ir.ID) -} - -func (ir IntegrationReference) CacheKey(path ...string) string { - if len(path) == 0 { - return ir.String() - } - return fmt.Sprintf("%s:%s", ir.String(), strings.Join(path, ":")) -} - // User represents a user of the system type User struct { Id string `gorm:"primaryKey"` @@ -91,7 +75,7 @@ type User struct { // Issued of the user Issued string `gorm:"default:api"` - IntegrationReference IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } // IsBlocked returns true if the user is blocked, false otherwise @@ -113,12 +97,20 @@ func (u *User) HasAdminPower() bool { } // ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { +func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups if autoGroups == nil { autoGroups = []string{} } + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + if userData == nil { return &UserInfo{ ID: u.Id, @@ -131,6 +123,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsBlocked: u.Blocked, LastLogin: u.LastLogin, Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, }, nil } if userData.ID != u.Id { @@ -153,6 +148,9 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { IsBlocked: u.Blocked, LastLogin: u.LastLogin, Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, }, nil } @@ -358,7 +356,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite am.StoreEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) - return newUser.ToUserInfo(idpUser) + return newUser.ToUserInfo(idpUser, account.Settings) } // GetUser looks up a user by provided authorization claims. @@ -905,9 +903,9 @@ func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string if err != nil { return nil, err } - return newUser.ToUserInfo(userData) + return newUser.ToUserInfo(userData, account.Settings) } - return newUser.ToUserInfo(nil) + return newUser.ToUserInfo(nil, account.Settings) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist @@ -998,7 +996,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( // if user is not an admin then show only current user and do not show other users continue } - info, err := accountUser.ToUserInfo(nil) + info, err := accountUser.ToUserInfo(nil, account.Settings) if err != nil { return nil, err } @@ -1015,7 +1013,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( var info *UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.ToUserInfo(queriedUser) + info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { return nil, err } @@ -1024,6 +1022,15 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( if localUser.IsServiceUser { name = localUser.ServiceUserName } + + dashboardViewPermissions := "full" + if !localUser.HasAdminPower() { + dashboardViewPermissions = "limited" + if account.Settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + info = &UserInfo{ ID: localUser.Id, Email: "", @@ -1033,6 +1040,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( Status: string(UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, + Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) diff --git a/management/server/user_test.go b/management/server/user_test.go index 50cd726ef..c92f87e6c 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" ) @@ -276,7 +277,7 @@ func TestUser_Copy(t *testing.T) { LastLogin: time.Now().UTC(), CreatedAt: time.Now().UTC(), Issued: "test", - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 0, IntegrationType: "test", }, @@ -603,8 +604,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, + Store: store, + eventStore: &activity.InMemoryEventStore{}, + integratedPeerValidator: MocIntegratedValidator{}, } testCases := []struct { @@ -709,6 +711,83 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { assert.Equal(t, 2, regular) } +func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { + testCases := []struct { + name string + role UserRole + limitedViewSettings bool + expectedDashboardPermissions string + }{ + { + name: "Regular user, no limited view settings", + role: UserRoleUser, + limitedViewSettings: false, + expectedDashboardPermissions: "limited", + }, + { + name: "Admin user, no limited view settings", + role: UserRoleAdmin, + limitedViewSettings: false, + expectedDashboardPermissions: "full", + }, + { + name: "Owner, no limited view settings", + role: UserRoleOwner, + limitedViewSettings: false, + expectedDashboardPermissions: "full", + }, + { + name: "Regular user, limited view settings", + role: UserRoleUser, + limitedViewSettings: true, + expectedDashboardPermissions: "blocked", + }, + { + name: "Admin user, limited view settings", + role: UserRoleAdmin, + limitedViewSettings: true, + expectedDashboardPermissions: "full", + }, + { + name: "Owner, limited view settings", + role: UserRoleOwner, + limitedViewSettings: true, + expectedDashboardPermissions: "full", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + delete(account.Users, mockUserID) + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + users, err := am.ListUsers(mockAccountID) + if err != nil { + t.Fatalf("Error when checking user role: %s", err) + } + + assert.Equal(t, 1, len(users)) + + userInfo, _ := users[0].ToUserInfo(nil, account.Settings) + assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) + }) + } + +} + func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -716,7 +795,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { Id: "externalUser", Role: UserRoleUser, Issued: UserIssuedIntegration, - IntegrationReference: IntegrationReference{ + IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", }, diff --git a/util/grpc/dialer_linux.go b/util/grpc/dialer.go similarity index 56% rename from util/grpc/dialer_linux.go rename to util/grpc/dialer.go index b29ee4b29..96b2bc32b 100644 --- a/util/grpc/dialer_linux.go +++ b/util/grpc/dialer.go @@ -1,11 +1,10 @@ -//go:build !android - package grpc import ( "context" "net" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" nbnet "github.com/netbirdio/netbird/util/net" @@ -13,6 +12,11 @@ import ( func WithCustomDialer() grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { - return nbnet.NewDialer().DialContext(ctx, "tcp", addr) + conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) + if err != nil { + log.Errorf("Failed to dial: %s", err) + return nil, err + } + return conn, nil }) } diff --git a/util/grpc/dialer_generic.go b/util/grpc/dialer_generic.go deleted file mode 100644 index 1c2285b14..000000000 --- a/util/grpc/dialer_generic.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !linux || android - -package grpc - -import "google.golang.org/grpc" - -func WithCustomDialer() grpc.DialOption { - return grpc.EmptyDialOption{} -} diff --git a/util/net/dialer.go b/util/net/dialer.go new file mode 100644 index 000000000..0786c667e --- /dev/null +++ b/util/net/dialer.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// Dialer extends the standard net.Dialer with the ability to execute hooks before +// and after connections. This can be used to bypass the VPN for connections using this dialer. +type Dialer struct { + *net.Dialer +} + +// NewDialer returns a customized net.Dialer with overridden Control method +func NewDialer() *Dialer { + dialer := &Dialer{ + Dialer: &net.Dialer{}, + } + dialer.init() + + return dialer +} diff --git a/util/net/dialer_generic.go b/util/net/dialer_generic.go index a3c3ad67c..4eda710ac 100644 --- a/util/net/dialer_generic.go +++ b/util/net/dialer_generic.go @@ -1,19 +1,163 @@ -//go:build !linux || android +//go:build !android && !ios package net import ( + "context" + "fmt" "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" ) -func NewDialer() *net.Dialer { - return &net.Dialer{} +type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error +type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error + +var ( + dialerDialHooksMutex sync.RWMutex + dialerDialHooks []DialerDialHookFunc + dialerCloseHooksMutex sync.RWMutex + dialerCloseHooks []DialerCloseHookFunc +) + +// AddDialerHook allows adding a new hook to be executed before dialing. +func AddDialerHook(hook DialerDialHookFunc) { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = append(dialerDialHooks, hook) +} + +// AddDialerCloseHook allows adding a new hook to be executed on connection close. +func AddDialerCloseHook(hook DialerCloseHookFunc) { + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = append(dialerCloseHooks, hook) +} + +// RemoveDialerHook removes all dialer hooks. +func RemoveDialerHooks() { + dialerDialHooksMutex.Lock() + defer dialerDialHooksMutex.Unlock() + dialerDialHooks = nil + + dialerCloseHooksMutex.Lock() + defer dialerCloseHooksMutex.Unlock() + dialerCloseHooks = nil +} + +// DialContext wraps the net.Dialer's DialContext method to use the custom connection +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + var resolver *net.Resolver + if d.Resolver != nil { + resolver = d.Resolver + } + + connID := GenerateConnID() + if dialerDialHooks != nil { + if err := callDialerHooks(ctx, connID, address, resolver); err != nil { + log.Errorf("Failed to call dialer hooks: %v", err) + } + } + + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + // Wrap the connection in Conn to handle Close with hooks + return &Conn{Conn: conn, ID: connID}, nil +} + +// Dial wraps the net.Dialer's Dial method to use the custom connection +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} + +func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { + host, _, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("split host and port: %w", err) + } + ips, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("failed to resolve address %s: %w", address, err) + } + + log.Debugf("Dialer resolved IPs for %s: %v", address, ips) + + var result *multierror.Error + + dialerDialHooksMutex.RLock() + defer dialerDialHooksMutex.RUnlock() + for _, hook := range dialerDialHooks { + if err := hook(ctx, connID, ips); err != nil { + result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err)) + } + } + + return result.ErrorOrNil() } func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - return net.DialUDP(network, laddr, raddr) + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil } func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - return net.DialTCP(network, laddr, raddr) + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil } diff --git a/util/net/dialer_linux.go b/util/net/dialer_linux.go index d559490c5..aed5c59a3 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_linux.go @@ -2,59 +2,11 @@ package net -import ( - "context" - "fmt" - "net" - "syscall" +import "syscall" - log "github.com/sirupsen/logrus" -) - -func NewDialer() *net.Dialer { - return &net.Dialer{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, +// init configures the net.Dialer Control function to set the fwmark on the socket +func (d *Dialer) init() { + d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type") - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.DialContext(context.Background(), network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type") - } - - return tcpConn, nil -} diff --git a/util/net/dialer_mobile.go b/util/net/dialer_mobile.go new file mode 100644 index 000000000..b95aaa973 --- /dev/null +++ b/util/net/dialer_mobile.go @@ -0,0 +1,15 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + return net.DialUDP(network, laddr, raddr) +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + return net.DialTCP(network, laddr, raddr) +} diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_nonlinux.go new file mode 100644 index 000000000..3254e6d06 --- /dev/null +++ b/util/net/dialer_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (d *Dialer) init() { +} diff --git a/util/net/listener.go b/util/net/listener.go new file mode 100644 index 000000000..f4d769f58 --- /dev/null +++ b/util/net/listener.go @@ -0,0 +1,21 @@ +package net + +import ( + "net" +) + +// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before +// responding via the socket and after closing. This can be used to bypass the VPN for listeners. +type ListenerConfig struct { + *net.ListenConfig +} + +// NewListener creates a new ListenerConfig instance. +func NewListener() *ListenerConfig { + listener := &ListenerConfig{ + ListenConfig: &net.ListenConfig{}, + } + listener.init() + + return listener +} diff --git a/util/net/listener_generic.go b/util/net/listener_generic.go index 241c744e5..451279e9d 100644 --- a/util/net/listener_generic.go +++ b/util/net/listener_generic.go @@ -1,13 +1,163 @@ -//go:build !linux || android +//go:build !android && !ios package net -import "net" +import ( + "context" + "fmt" + "net" + "sync" -func NewListener() *net.ListenConfig { - return &net.ListenConfig{} + log "github.com/sirupsen/logrus" +) + +// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn. +type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error + +// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. +type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error + +var ( + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc +) + +// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. +func AddListenerWriteHook(hook ListenerWriteHookFunc) { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = append(listenerWriteHooks, hook) } -func ListenUDP(network string, locAddr *net.UDPAddr) (*net.UDPConn, error) { - return net.ListenUDP(network, locAddr) +// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection. +func AddListenerCloseHook(hook ListenerCloseHookFunc) { + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = append(listenerCloseHooks, hook) +} + +// RemoveListenerHooks removes all dialer hooks. +func RemoveListenerHooks() { + listenerWriteHooksMutex.Lock() + defer listenerWriteHooksMutex.Unlock() + listenerWriteHooks = nil + + listenerCloseHooksMutex.Lock() + defer listenerCloseHooksMutex.Unlock() + listenerCloseHooks = nil +} + +// ListenPacket listens on the network address and returns a PacketConn +// which includes support for write hooks. +func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + pc, err := l.ListenConfig.ListenPacket(ctx, network, address) + if err != nil { + return nil, fmt.Errorf("listen packet: %w", err) + } + connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil +} + +// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality. +type PacketConn struct { + net.PacketConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.PacketConn.WriteTo(b, addr) +} + +// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +func (c *PacketConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.PacketConn) +} + +// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality. +type UDPConn struct { + *net.UDPConn + ID ConnectionID + seenAddrs *sync.Map +} + +// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand. +func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + callWriteHooks(c.ID, c.seenAddrs, b, addr) + return c.UDPConn.WriteTo(b, addr) +} + +// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +func (c *UDPConn) Close() error { + c.seenAddrs = &sync.Map{} + return closeConn(c.ID, c.UDPConn) +} + +func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { + // Lookup the address in the seenAddrs map to avoid calling the hooks for every write + if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { + ipStr, _, splitErr := net.SplitHostPort(addr.String()) + if splitErr != nil { + log.Errorf("Error splitting IP address and port: %v", splitErr) + return + } + + ip, err := net.ResolveIPAddr("ip", ipStr) + if err != nil { + log.Errorf("Error resolving IP address: %v", err) + return + } + log.Debugf("Listener resolved IP for %s: %s", addr, ip) + + func() { + listenerWriteHooksMutex.RLock() + defer listenerWriteHooksMutex.RUnlock() + + for _, hook := range listenerWriteHooks { + if err := hook(id, ip, b); err != nil { + log.Errorf("Error executing listener write hook: %v", err) + } + } + }() + } +} + +func closeConn(id ConnectionID, conn net.PacketConn) error { + err := conn.Close() + + listenerCloseHooksMutex.RLock() + defer listenerCloseHooksMutex.RUnlock() + + for _, hook := range listenerCloseHooks { + if err := hook(id, conn); err != nil { + log.Errorf("Error executing listener close hook: %v", err) + } + } + + return err +} + +// ListenUDP listens on the network address and returns a transport.UDPConn +// which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) { + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil } diff --git a/util/net/listener_linux.go b/util/net/listener_linux.go index 7b9bda97c..8d332160a 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_linux.go @@ -3,28 +3,12 @@ package net import ( - "context" - "fmt" - "net" "syscall" ) -func NewListener() *net.ListenConfig { - return &net.ListenConfig{ - Control: func(network, address string, c syscall.RawConn) error { - return SetRawSocketMark(c) - }, +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { + return SetRawSocketMark(c) } } - -func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { - pc, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listening on %s:%s with fwmark: %w", network, laddr, err) - } - udpConn, ok := pc.(*net.UDPConn) - if !ok { - return nil, fmt.Errorf("packetConn is not a *net.UDPConn") - } - return udpConn, nil -} diff --git a/util/net/listener_mobile.go b/util/net/listener_mobile.go new file mode 100644 index 000000000..0dbbb360b --- /dev/null +++ b/util/net/listener_mobile.go @@ -0,0 +1,11 @@ +//go:build android || ios + +package net + +import ( + "net" +) + +func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { + return net.ListenUDP(network, laddr) +} diff --git a/util/net/listener_nonlinux.go b/util/net/listener_nonlinux.go new file mode 100644 index 000000000..fb6eadaaa --- /dev/null +++ b/util/net/listener_nonlinux.go @@ -0,0 +1,6 @@ +//go:build !linux || android + +package net + +func (l *ListenerConfig) init() { +} diff --git a/util/net/net.go b/util/net/net.go index 5714e5229..9ea7ae803 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -1,6 +1,17 @@ package net +import "github.com/google/uuid" + const ( // NetbirdFwmark is the fwmark value used by Netbird via wireguard NetbirdFwmark = 0x1BD00 ) + +// ConnectionID provides a globally unique identifier for network connections. +// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook. +type ConnectionID string + +// GenerateConnID generates a unique identifier for each connection. +func GenerateConnID() ConnectionID { + return ConnectionID(uuid.NewString()) +}