Apply system routes right away

This commit is contained in:
Viktor Liu
2025-05-31 12:03:07 +02:00
parent 8208a7939c
commit ae01335bfe
2 changed files with 105 additions and 50 deletions

View File

@@ -57,6 +57,15 @@ type RouteHandler interface {
RemoveAllowedIPs() error RemoveAllowedIPs() error
} }
type ClientNetworkConfig struct {
Context context.Context
DNSRouteInterval time.Duration
WGInterface iface.WGIface
StatusRecorder *peer.Status
Route *route.Route
Handler RouteHandler
}
type clientNetwork struct { type clientNetwork struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@@ -71,40 +80,19 @@ type clientNetwork struct {
updateSerial uint64 updateSerial uint64
} }
func newClientNetworkWatcher( func newClientNetworkWatcher(config ClientNetworkConfig) *clientNetwork {
ctx context.Context, ctx, cancel := context.WithCancel(config.Context)
dnsRouteInterval time.Duration,
wgInterface iface.WGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{ client := &clientNetwork{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
statusRecorder: statusRecorder, statusRecorder: config.StatusRecorder,
wgInterface: wgInterface, wgInterface: config.WGInterface,
routes: make(map[route.ID]*route.Route), routes: make(map[route.ID]*route.Route),
routePeersNotifiers: make(map[string]chan struct{}), routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate), routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}), peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute( handler: config.Handler,
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerStore,
useNewDNSRoute,
),
} }
return client return client
} }

View File

@@ -11,9 +11,11 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
@@ -88,6 +90,7 @@ type DefaultManager struct {
useNewDNSRoute bool useNewDNSRoute bool
disableClientRoutes bool disableClientRoutes bool
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]RouteHandler
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
@@ -111,6 +114,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
peerStore: config.PeerStore, peerStore: config.PeerStore,
disableClientRoutes: config.DisableClientRoutes, disableClientRoutes: config.DisableClientRoutes,
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]RouteHandler),
} }
useNoop := netstack.IsEnabled() || config.DisableClientRoutes useNoop := netstack.IsEnabled() || config.DisableClientRoutes
@@ -265,6 +269,54 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
toAdd := make(map[route.HAUniqueID]*route.Route)
toRemove := make(map[route.HAUniqueID]RouteHandler)
for id, routes := range newRoutes {
if len(routes) > 0 {
toAdd[id] = routes[0]
}
}
for id, activeHandler := range m.activeRoutes {
if _, exists := toAdd[id]; exists {
delete(toAdd, id)
} else {
toRemove[id] = activeHandler
}
}
var merr *multierror.Error
for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
}
delete(m.activeRoutes, id)
}
for id, route := range toAdd {
handler := handlerFromRoute(
route,
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsRouteInterval,
m.statusRecorder,
m.wgInterface,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue
}
m.activeRoutes[id] = handler
}
return nberrors.FormatErrorOrNil(merr)
}
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@@ -281,6 +333,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
if !m.disableClientRoutes { if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
if err := m.updateSystemRoutes(filteredClientRoutes); err != nil {
log.Errorf("Failed to update system routes: %v", err)
}
m.updateClientNetworks(updateSerial, filteredClientRoutes) m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes) m.notifier.OnNewRoutes(filteredClientRoutes)
} }
@@ -341,6 +398,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.notifier.OnNewRoutes(networks) m.notifier.OnNewRoutes(networks)
if err := m.updateSystemRoutes(networks); err != nil {
log.Errorf("failed to update system routes during selection: %v", err)
}
m.stopObsoleteClients(networks) m.stopObsoleteClients(networks)
for id, routes := range networks { for id, routes := range networks {
@@ -349,18 +410,21 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue continue
} }
clientNetworkWatcher := newClientNetworkWatcher( handler := m.activeRoutes[id]
m.ctx, if handler == nil {
m.dnsRouteInterval, log.Warnf("no active handler found for route %s", id)
m.wgInterface, continue
m.statusRecorder, }
routes[0],
m.routeRefCounter, config := ClientNetworkConfig{
m.allowedIPsRefCounter, Context: m.ctx,
m.dnsServer, DNSRouteInterval: m.dnsRouteInterval,
m.peerStore, WGInterface: m.wgInterface,
m.useNewDNSRoute, StatusRecorder: m.statusRecorder,
) Route: routes[0],
Handler: handler,
}
clientNetworkWatcher := newClientNetworkWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@@ -389,18 +453,21 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher( handler := m.activeRoutes[id]
m.ctx, if handler == nil {
m.dnsRouteInterval, log.Errorf("No active handler found for route %s", id)
m.wgInterface, continue
m.statusRecorder, }
routes[0],
m.routeRefCounter, config := ClientNetworkConfig{
m.allowedIPsRefCounter, Context: m.ctx,
m.dnsServer, DNSRouteInterval: m.dnsRouteInterval,
m.peerStore, WGInterface: m.wgInterface,
m.useNewDNSRoute, StatusRecorder: m.statusRecorder,
) Route: routes[0],
Handler: handler,
}
clientNetworkWatcher = newClientNetworkWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }