mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add route selection functionality for CLI and GUI (#1865)
This commit is contained in:
@@ -31,7 +31,7 @@ import (
|
||||
|
||||
// RunClient with main logic.
|
||||
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// RunClientWithProbes runs the client's main logic with probes attached
|
||||
@@ -43,8 +43,9 @@ func RunClientWithProbes(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
engineChan chan<- *Engine,
|
||||
) error {
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
|
||||
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
|
||||
}
|
||||
|
||||
// RunClientMobile with main logic on mobile system
|
||||
@@ -66,7 +67,7 @@ func RunClientMobile(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func RunClientiOS(
|
||||
@@ -82,7 +83,7 @@ func RunClientiOS(
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
}
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
|
||||
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
func runClient(
|
||||
@@ -94,6 +95,7 @@ func runClient(
|
||||
signalProbe *Probe,
|
||||
relayProbe *Probe,
|
||||
wgProbe *Probe,
|
||||
engineChan chan<- *Engine,
|
||||
) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -243,6 +245,9 @@ func runClient(
|
||||
log.Errorf("error while starting Netbird Connection Engine: %s", err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
if engineChan != nil {
|
||||
engineChan <- engine
|
||||
}
|
||||
|
||||
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
|
||||
state.Set(StatusConnected)
|
||||
@@ -252,6 +257,10 @@ func runClient(
|
||||
|
||||
backOff.Reset()
|
||||
|
||||
if engineChan != nil {
|
||||
engineChan <- nil
|
||||
}
|
||||
|
||||
err = engine.Stop()
|
||||
if err != nil {
|
||||
log.Errorf("failed stopping engine %v", err)
|
||||
|
||||
@@ -111,6 +111,9 @@ type Engine struct {
|
||||
// TURNs is a list of STUN servers used by ICE
|
||||
TURNs []*stun.URI
|
||||
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes map[string][]*route.Route
|
||||
|
||||
cancel context.CancelFunc
|
||||
|
||||
ctx context.Context
|
||||
@@ -216,6 +219,8 @@ func (e *Engine) Stop() error {
|
||||
return err
|
||||
}
|
||||
|
||||
e.clientRoutes = nil
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.CLose() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -695,11 +700,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
if protoRoutes == nil {
|
||||
protoRoutes = []*mgmProto.Route{}
|
||||
}
|
||||
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
|
||||
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
|
||||
if err != nil {
|
||||
log.Errorf("failed to update routes, err: %v", err)
|
||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||
}
|
||||
|
||||
e.clientRoutes = clientRoutes
|
||||
|
||||
protoDNSConfig := networkMap.GetDNSConfig()
|
||||
if protoDNSConfig == nil {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
@@ -1229,6 +1237,28 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the current routes from the route map
|
||||
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
|
||||
return e.clientRoutes
|
||||
}
|
||||
|
||||
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
|
||||
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
|
||||
routes := make(map[string][]*route.Route, len(e.clientRoutes))
|
||||
for id, v := range e.clientRoutes {
|
||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
||||
id = id[:i]
|
||||
}
|
||||
routes[id] = v
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetRouteManager returns the route manager
|
||||
func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
@@ -577,10 +578,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
}{}
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
input.inputSerial = updateSerial
|
||||
input.inputRoutes = newRoutes
|
||||
return testCase.inputErr
|
||||
return nil, nil, testCase.inputErr
|
||||
},
|
||||
}
|
||||
|
||||
@@ -597,8 +598,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
|
||||
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -742,8 +743,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
|
||||
mockRouteManager := &routemanager.MockManager{
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
return nil
|
||||
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
return nil, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
@@ -28,7 +29,9 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
// Manager is a route manager interface
|
||||
type Manager interface {
|
||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
||||
TriggerSelection(map[string][]*route.Route)
|
||||
GetRouteSelector() *routeselector.RouteSelector
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
EnableServerRouter(firewall firewall.Manager) error
|
||||
@@ -41,6 +44,7 @@ type DefaultManager struct {
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
clientNetworks map[string]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
statusRecorder *peer.Status
|
||||
wgInterface *iface.WGIface
|
||||
@@ -54,6 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
clientNetworks: make(map[string]*clientNetwork),
|
||||
routeSelector: routeselector.NewRouteSelector(),
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
@@ -117,28 +122,29 @@ func (m *DefaultManager) Stop() {
|
||||
}
|
||||
|
||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
log.Infof("not updating routes as context is closed")
|
||||
return m.ctx.Err()
|
||||
return nil, nil, m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
|
||||
m.notifier.onNewRoutes(newClientRoutesIDMap)
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.onNewRoutes(filteredClientRoutes)
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update routes: %w", err)
|
||||
return nil, nil, fmt.Errorf("update routes: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return newServerRoutesMap, newClientRoutesIDMap, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,16 +158,51 @@ func (m *DefaultManager) InitialRouteRange() []string {
|
||||
return m.notifier.initialRouteRanges()
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
// GetRouteSelector returns the route selector
|
||||
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
return m.routeSelector
|
||||
}
|
||||
|
||||
// GetClientRoutes returns the client routes
|
||||
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
|
||||
return m.clientNetworks
|
||||
}
|
||||
|
||||
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
|
||||
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
networks = m.routeSelector.FilterSelected(networks)
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
if _, found := m.clientNetworks[id]; found {
|
||||
// don't touch existing client network watchers
|
||||
continue
|
||||
}
|
||||
|
||||
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.peersStateAndUpdateWatcher()
|
||||
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
|
||||
}
|
||||
}
|
||||
|
||||
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
|
||||
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
|
||||
for id, client := range m.clientNetworks {
|
||||
_, found := networks[id]
|
||||
if !found {
|
||||
log.Debugf("stopping client network watcher, %s", id)
|
||||
if _, ok := networks[id]; !ok {
|
||||
log.Debugf("Stopping client network watcher, %s", id)
|
||||
client.stop()
|
||||
delete(m.clientNetworks, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
|
||||
// removing routes that do not exist as per the update from the Management service.
|
||||
m.stopObsoleteClients(networks)
|
||||
|
||||
for id, routes := range networks {
|
||||
clientNetworkWatcher, found := m.clientNetworks[id]
|
||||
@@ -178,7 +219,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
||||
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
|
||||
newClientRoutesIDMap := make(map[string][]*route.Route)
|
||||
newServerRoutesMap := make(map[string]*route.Route)
|
||||
ownNetworkIDs := make(map[string]bool)
|
||||
@@ -210,7 +251,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
||||
}
|
||||
|
||||
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
|
||||
_, crMap := m.classifiesRoutes(initialRoutes)
|
||||
_, crMap := m.classifyRoutes(initialRoutes)
|
||||
rs := make([]*route.Route, 0)
|
||||
for _, routes := range crMap {
|
||||
rs = append(rs, routes...)
|
||||
|
||||
@@ -428,11 +428,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
}
|
||||
|
||||
if len(testCase.inputInitRoutes) > 0 {
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes with init routes")
|
||||
}
|
||||
|
||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||
require.NoError(t, err, "should update routes")
|
||||
|
||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
||||
|
||||
@@ -7,14 +7,17 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/iface"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
|
||||
StopFunc func()
|
||||
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
|
||||
TriggerSelectionFunc func(map[string][]*route.Route)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
StopFunc func()
|
||||
}
|
||||
|
||||
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
||||
@@ -27,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string {
|
||||
}
|
||||
|
||||
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
|
||||
if m.UpdateRoutesFunc != nil {
|
||||
return m.UpdateRoutesFunc(updateSerial, newRoutes)
|
||||
}
|
||||
return fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
|
||||
}
|
||||
|
||||
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
|
||||
if m.TriggerSelectionFunc != nil {
|
||||
m.TriggerSelectionFunc(networks)
|
||||
}
|
||||
}
|
||||
|
||||
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
|
||||
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
|
||||
if m.GetRouteSelectorFunc != nil {
|
||||
return m.GetRouteSelectorFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start mock implementation of Start from Manager interface
|
||||
|
||||
@@ -304,7 +304,10 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
||||
Dst: ipNet,
|
||||
}
|
||||
|
||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
if err := netlink.RouteDel(route); err != nil &&
|
||||
!errors.Is(err, syscall.ESRCH) &&
|
||||
!errors.Is(err, syscall.ENOENT) &&
|
||||
!errors.Is(err, syscall.EAFNOSUPPORT) {
|
||||
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
||||
}
|
||||
|
||||
|
||||
132
client/internal/routeselector/routeselector.go
Normal file
132
client/internal/routeselector/routeselector.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package routeselector
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
route "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type RouteSelector struct {
|
||||
selectedRoutes map[string]struct{}
|
||||
selectAll bool
|
||||
}
|
||||
|
||||
func NewRouteSelector() *RouteSelector {
|
||||
return &RouteSelector{
|
||||
selectedRoutes: map[string]struct{}{},
|
||||
// default selects all routes
|
||||
selectAll: true,
|
||||
}
|
||||
}
|
||||
|
||||
// SelectRoutes updates the selected routes based on the provided route IDs.
|
||||
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
|
||||
if !appendRoute {
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
|
||||
rs.selectedRoutes[route] = struct{}{}
|
||||
}
|
||||
rs.selectAll = false
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// SelectAllRoutes sets the selector to select all routes.
|
||||
func (rs *RouteSelector) SelectAllRoutes() {
|
||||
rs.selectAll = true
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
// DeselectRoutes removes specific routes from the selection.
|
||||
// If the selector is in "select all" mode, it will transition to "select specific" mode.
|
||||
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
|
||||
if rs.selectAll {
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
for _, route := range allRoutes {
|
||||
rs.selectedRoutes[route] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var multiErr *multierror.Error
|
||||
|
||||
for _, route := range routes {
|
||||
if !slices.Contains(allRoutes, route) {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
|
||||
continue
|
||||
}
|
||||
delete(rs.selectedRoutes, route)
|
||||
}
|
||||
|
||||
if multiErr != nil {
|
||||
multiErr.ErrorFormat = formatError
|
||||
}
|
||||
|
||||
return multiErr.ErrorOrNil()
|
||||
}
|
||||
|
||||
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
|
||||
func (rs *RouteSelector) DeselectAllRoutes() {
|
||||
rs.selectAll = false
|
||||
rs.selectedRoutes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
// IsSelected checks if a specific route is selected.
|
||||
func (rs *RouteSelector) IsSelected(routeID string) bool {
|
||||
if rs.selectAll {
|
||||
return true
|
||||
}
|
||||
_, selected := rs.selectedRoutes[routeID]
|
||||
return selected
|
||||
}
|
||||
|
||||
// FilterSelected removes unselected routes from the provided map.
|
||||
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
|
||||
if rs.selectAll {
|
||||
return maps.Clone(routes)
|
||||
}
|
||||
|
||||
filtered := map[string][]*route.Route{}
|
||||
for id, rt := range routes {
|
||||
netID := id
|
||||
if i := strings.LastIndex(id, "-"); i != -1 {
|
||||
netID = id[:i]
|
||||
}
|
||||
if rs.IsSelected(netID) {
|
||||
filtered[id] = rt
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
func formatError(es []error) string {
|
||||
if len(es) == 1 {
|
||||
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
|
||||
}
|
||||
|
||||
points := make([]string, len(es))
|
||||
for i, err := range es {
|
||||
points[i] = fmt.Sprintf("* %s", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%d errors occurred:\n\t%s",
|
||||
len(es), strings.Join(points, "\n\t"))
|
||||
}
|
||||
275
client/internal/routeselector/routeselector_test.go
Normal file
275
client/internal/routeselector/routeselector_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package routeselector_test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestRouteSelector_SelectRoutes(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
selectRoutes []string
|
||||
append bool
|
||||
|
||||
wantSelected []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Select specific routes, initial all selected",
|
||||
selectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{"route1", "route2"},
|
||||
},
|
||||
{
|
||||
name: "Select specific routes, initial all deselected",
|
||||
initialSelected: []string{},
|
||||
selectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{"route1", "route2"},
|
||||
},
|
||||
{
|
||||
name: "Select specific routes with initial selection",
|
||||
initialSelected: []string{"route1"},
|
||||
selectRoutes: []string{"route2", "route3"},
|
||||
wantSelected: []string{"route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Select non-existing route",
|
||||
selectRoutes: []string{"route1", "route4"},
|
||||
wantSelected: []string{"route1"},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "Append route with initial selection",
|
||||
initialSelected: []string{"route1"},
|
||||
selectRoutes: []string{"route2"},
|
||||
append: true,
|
||||
wantSelected: []string{"route1", "route2"},
|
||||
},
|
||||
{
|
||||
name: "Append route without initial selection",
|
||||
selectRoutes: []string{"route2"},
|
||||
append: true,
|
||||
wantSelected: []string{"route2"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
wantSelected []string
|
||||
}{
|
||||
{
|
||||
name: "Initial all selected",
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Initial all deselected",
|
||||
initialSelected: []string{},
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Initial some selected",
|
||||
initialSelected: []string{"route1"},
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
{
|
||||
name: "Initial all selected",
|
||||
initialSelected: []string{"route1", "route2", "route3"},
|
||||
wantSelected: []string{"route1", "route2", "route3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
rs.SelectAllRoutes()
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_DeselectRoutes(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
deselectRoutes []string
|
||||
|
||||
wantSelected []string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Deselect specific routes, initial all selected",
|
||||
deselectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{"route3"},
|
||||
},
|
||||
{
|
||||
name: "Deselect specific routes, initial all deselected",
|
||||
initialSelected: []string{},
|
||||
deselectRoutes: []string{"route1", "route2"},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Deselect specific routes with initial selection",
|
||||
initialSelected: []string{"route1", "route2"},
|
||||
deselectRoutes: []string{"route1", "route3"},
|
||||
wantSelected: []string{"route2"},
|
||||
},
|
||||
{
|
||||
name: "Deselect non-existing route",
|
||||
initialSelected: []string{"route1", "route2"},
|
||||
deselectRoutes: []string{"route1", "route4"},
|
||||
wantSelected: []string{"route2"},
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_DeselectAll(t *testing.T) {
|
||||
allRoutes := []string{"route1", "route2", "route3"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialSelected []string
|
||||
|
||||
wantSelected []string
|
||||
}{
|
||||
{
|
||||
name: "Initial all selected",
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Initial all deselected",
|
||||
initialSelected: []string{},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Initial some selected",
|
||||
initialSelected: []string{"route1", "route2"},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Initial all selected",
|
||||
initialSelected: []string{"route1", "route2", "route3"},
|
||||
wantSelected: []string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
if tt.initialSelected != nil {
|
||||
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
rs.DeselectAllRoutes()
|
||||
|
||||
for _, id := range allRoutes {
|
||||
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteSelector_IsSelected(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, rs.IsSelected("route1"))
|
||||
assert.True(t, rs.IsSelected("route2"))
|
||||
assert.False(t, rs.IsSelected("route3"))
|
||||
assert.False(t, rs.IsSelected("route4"))
|
||||
}
|
||||
|
||||
func TestRouteSelector_FilterSelected(t *testing.T) {
|
||||
rs := routeselector.NewRouteSelector()
|
||||
|
||||
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
|
||||
require.NoError(t, err)
|
||||
|
||||
routes := map[string][]*route.Route{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
"route3-172.16.0.0/12": {},
|
||||
}
|
||||
|
||||
filtered := rs.FilterSelected(routes)
|
||||
|
||||
assert.Equal(t, map[string][]*route.Route{
|
||||
"route1-10.0.0.0/8": {},
|
||||
"route2-192.168.0.0/16": {},
|
||||
}, filtered)
|
||||
}
|
||||
Reference in New Issue
Block a user