Add route selection functionality for CLI and GUI (#1865)

This commit is contained in:
Viktor Liu
2024-04-23 14:42:53 +02:00
committed by GitHub
parent 3477108ce7
commit f51dc13f8c
20 changed files with 1650 additions and 146 deletions

View File

@@ -31,7 +31,7 @@ import (
// RunClient with main logic.
func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil, nil)
}
// RunClientWithProbes runs the client's main logic with probes attached
@@ -43,8 +43,9 @@ func RunClientWithProbes(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe, engineChan)
}
// RunClientMobile with main logic on mobile system
@@ -66,7 +67,7 @@ func RunClientMobile(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
}
func RunClientiOS(
@@ -82,7 +83,7 @@ func RunClientiOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil, nil)
}
func runClient(
@@ -94,6 +95,7 @@ func runClient(
signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
engineChan chan<- *Engine,
) error {
defer func() {
if r := recover(); r != nil {
@@ -243,6 +245,9 @@ func runClient(
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
if engineChan != nil {
engineChan <- engine
}
log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected)
@@ -252,6 +257,10 @@ func runClient(
backOff.Reset()
if engineChan != nil {
engineChan <- nil
}
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)

View File

@@ -111,6 +111,9 @@ type Engine struct {
// TURNs is a list of STUN servers used by ICE
TURNs []*stun.URI
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes map[string][]*route.Route
cancel context.CancelFunc
ctx context.Context
@@ -216,6 +219,8 @@ func (e *Engine) Stop() error {
return err
}
e.clientRoutes = nil
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.CLose() asynchronously
time.Sleep(500 * time.Millisecond)
@@ -695,11 +700,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update routes, err: %v", err)
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutes = clientRoutes
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
@@ -1229,6 +1237,28 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
}
}
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() map[string][]*route.Route {
return e.clientRoutes
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[string][]*route.Route {
routes := make(map[string][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
if i := strings.LastIndex(id, "-"); i != -1 {
id = id[:i]
}
routes[id] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -577,10 +578,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return testCase.inputErr
return nil, nil, testCase.inputErr
},
}
@@ -597,8 +598,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "routes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "routes should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match")
})
}
}
@@ -742,8 +743,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
return nil, nil, nil
},
}

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -28,7 +29,9 @@ var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
// Manager is a route manager interface
type Manager interface {
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelection(map[string][]*route.Route)
GetRouteSelector() *routeselector.RouteSelector
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
@@ -41,6 +44,7 @@ type DefaultManager struct {
stop context.CancelFunc
mux sync.Mutex
clientNetworks map[string]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
statusRecorder *peer.Status
wgInterface *iface.WGIface
@@ -54,6 +58,7 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
ctx: mCTX,
stop: cancel,
clientNetworks: make(map[string]*clientNetwork),
routeSelector: routeselector.NewRouteSelector(),
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
@@ -117,28 +122,29 @@ func (m *DefaultManager) Stop() {
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return m.ctx.Err()
return nil, nil, m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
newServerRoutesMap, newClientRoutesIDMap := m.classifiesRoutes(newRoutes)
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
m.updateClientNetworks(updateSerial, newClientRoutesIDMap)
m.notifier.onNewRoutes(newClientRoutesIDMap)
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.onNewRoutes(filteredClientRoutes)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return fmt.Errorf("update routes: %w", err)
return nil, nil, fmt.Errorf("update routes: %w", err)
}
}
return nil
return newServerRoutesMap, newClientRoutesIDMap, nil
}
}
@@ -152,16 +158,51 @@ func (m *DefaultManager) InitialRouteRange() []string {
return m.notifier.initialRouteRanges()
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
// GetRouteSelector returns the route selector
func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[string]*clientNetwork {
return m.clientNetworks
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
func (m *DefaultManager) TriggerSelection(networks map[string][]*route.Route) {
m.mux.Lock()
defer m.mux.Unlock()
networks = m.routeSelector.FilterSelected(networks)
m.stopObsoleteClients(networks)
for id, routes := range networks {
if _, found := m.clientNetworks[id]; found {
// don't touch existing client network watchers
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.wgInterface, m.statusRecorder, routes[0].Network)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
}
}
// stopObsoleteClients stops the client network watcher for the networks that are not in the new list
func (m *DefaultManager) stopObsoleteClients(networks map[string][]*route.Route) {
for id, client := range m.clientNetworks {
_, found := networks[id]
if !found {
log.Debugf("stopping client network watcher, %s", id)
if _, ok := networks[id]; !ok {
log.Debugf("Stopping client network watcher, %s", id)
client.stop()
delete(m.clientNetworks, id)
}
}
}
func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[string][]*route.Route) {
// removing routes that do not exist as per the update from the Management service.
m.stopObsoleteClients(networks)
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
@@ -178,7 +219,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks map[
}
}
func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route) {
newClientRoutesIDMap := make(map[string][]*route.Route)
newServerRoutesMap := make(map[string]*route.Route)
ownNetworkIDs := make(map[string]bool)
@@ -210,7 +251,7 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifiesRoutes(initialRoutes)
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0)
for _, routes := range crMap {
rs = append(rs, routes...)

View File

@@ -428,11 +428,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
require.NoError(t, err, "should update routes with init routes")
}
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@@ -7,14 +7,17 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
StopFunc func()
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error)
TriggerSelectionFunc func(map[string][]*route.Route)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func()
}
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
@@ -27,11 +30,25 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[string]*route.Route, map[string][]*route.Route, error) {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return fmt.Errorf("method UpdateRoutes is not implemented")
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
}
func (m *MockManager) TriggerSelection(networks map[string][]*route.Route) {
if m.TriggerSelectionFunc != nil {
m.TriggerSelectionFunc(networks)
}
}
// GetRouteSelector mock implementation of GetRouteSelector from Manager interface
func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
if m.GetRouteSelectorFunc != nil {
return m.GetRouteSelectorFunc()
}
return nil
}
// Start mock implementation of Start from Manager interface

View File

@@ -304,7 +304,10 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}

View File

@@ -0,0 +1,132 @@
package routeselector
import (
"fmt"
"slices"
"strings"
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
route "github.com/netbirdio/netbird/route"
)
type RouteSelector struct {
selectedRoutes map[string]struct{}
selectAll bool
}
func NewRouteSelector() *RouteSelector {
return &RouteSelector{
selectedRoutes: map[string]struct{}{},
// default selects all routes
selectAll: true,
}
}
// SelectRoutes updates the selected routes based on the provided route IDs.
func (rs *RouteSelector) SelectRoutes(routes []string, appendRoute bool, allRoutes []string) error {
if !appendRoute {
rs.selectedRoutes = map[string]struct{}{}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
rs.selectedRoutes[route] = struct{}{}
}
rs.selectAll = false
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// SelectAllRoutes sets the selector to select all routes.
func (rs *RouteSelector) SelectAllRoutes() {
rs.selectAll = true
rs.selectedRoutes = map[string]struct{}{}
}
// DeselectRoutes removes specific routes from the selection.
// If the selector is in "select all" mode, it will transition to "select specific" mode.
func (rs *RouteSelector) DeselectRoutes(routes []string, allRoutes []string) error {
if rs.selectAll {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
for _, route := range allRoutes {
rs.selectedRoutes[route] = struct{}{}
}
}
var multiErr *multierror.Error
for _, route := range routes {
if !slices.Contains(allRoutes, route) {
multiErr = multierror.Append(multiErr, fmt.Errorf("route '%s' is not available", route))
continue
}
delete(rs.selectedRoutes, route)
}
if multiErr != nil {
multiErr.ErrorFormat = formatError
}
return multiErr.ErrorOrNil()
}
// DeselectAllRoutes deselects all routes, effectively disabling route selection.
func (rs *RouteSelector) DeselectAllRoutes() {
rs.selectAll = false
rs.selectedRoutes = map[string]struct{}{}
}
// IsSelected checks if a specific route is selected.
func (rs *RouteSelector) IsSelected(routeID string) bool {
if rs.selectAll {
return true
}
_, selected := rs.selectedRoutes[routeID]
return selected
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes map[string][]*route.Route) map[string][]*route.Route {
if rs.selectAll {
return maps.Clone(routes)
}
filtered := map[string][]*route.Route{}
for id, rt := range routes {
netID := id
if i := strings.LastIndex(id, "-"); i != -1 {
netID = id[:i]
}
if rs.IsSelected(netID) {
filtered[id] = rt
}
}
return filtered
}
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}

View File

@@ -0,0 +1,275 @@
package routeselector_test
import (
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func TestRouteSelector_SelectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
selectRoutes []string
append bool
wantSelected []string
wantError bool
}{
{
name: "Select specific routes, initial all selected",
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes, initial all deselected",
initialSelected: []string{},
selectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route1", "route2"},
},
{
name: "Select specific routes with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2", "route3"},
wantSelected: []string{"route2", "route3"},
},
{
name: "Select non-existing route",
selectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route1"},
wantError: true,
},
{
name: "Append route with initial selection",
initialSelected: []string{"route1"},
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route1", "route2"},
},
{
name: "Append route without initial selection",
selectRoutes: []string{"route2"},
append: true,
wantSelected: []string{"route2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.SelectRoutes(tt.selectRoutes, tt.append, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_SelectAllRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial some selected",
initialSelected: []string{"route1"},
wantSelected: []string{"route1", "route2", "route3"},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{"route1", "route2", "route3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.SelectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectRoutes(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
deselectRoutes []string
wantSelected []string
wantError bool
}{
{
name: "Deselect specific routes, initial all selected",
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{"route3"},
},
{
name: "Deselect specific routes, initial all deselected",
initialSelected: []string{},
deselectRoutes: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Deselect specific routes with initial selection",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route3"},
wantSelected: []string{"route2"},
},
{
name: "Deselect non-existing route",
initialSelected: []string{"route1", "route2"},
deselectRoutes: []string{"route1", "route4"},
wantSelected: []string{"route2"},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
err := rs.DeselectRoutes(tt.deselectRoutes, allRoutes)
if tt.wantError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_DeselectAll(t *testing.T) {
allRoutes := []string{"route1", "route2", "route3"}
tests := []struct {
name string
initialSelected []string
wantSelected []string
}{
{
name: "Initial all selected",
wantSelected: []string{},
},
{
name: "Initial all deselected",
initialSelected: []string{},
wantSelected: []string{},
},
{
name: "Initial some selected",
initialSelected: []string{"route1", "route2"},
wantSelected: []string{},
},
{
name: "Initial all selected",
initialSelected: []string{"route1", "route2", "route3"},
wantSelected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rs := routeselector.NewRouteSelector()
if tt.initialSelected != nil {
err := rs.SelectRoutes(tt.initialSelected, false, allRoutes)
require.NoError(t, err)
}
rs.DeselectAllRoutes()
for _, id := range allRoutes {
assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantSelected, id))
}
})
}
}
func TestRouteSelector_IsSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
assert.True(t, rs.IsSelected("route1"))
assert.True(t, rs.IsSelected("route2"))
assert.False(t, rs.IsSelected("route3"))
assert.False(t, rs.IsSelected("route4"))
}
func TestRouteSelector_FilterSelected(t *testing.T) {
rs := routeselector.NewRouteSelector()
err := rs.SelectRoutes([]string{"route1", "route2"}, false, []string{"route1", "route2", "route3"})
require.NoError(t, err)
routes := map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
"route3-172.16.0.0/12": {},
}
filtered := rs.FilterSelected(routes)
assert.Equal(t, map[string][]*route.Route{
"route1-10.0.0.0/8": {},
"route2-192.168.0.0/16": {},
}, filtered)
}