Use DNS route feature flag (#3048)

Co-authored-by: Viktor Liu <viktor@netbird.io>
This commit is contained in:
Zoltan Papp
2024-12-14 16:46:49 +01:00
committed by GitHub
parent c91d7808bf
commit 2fa1433063
5 changed files with 61 additions and 25 deletions

View File

@@ -802,14 +802,17 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap) e.acl.ApplyFiltering(networkMap)
} }
var dnsRouteFeatureFlag bool
if networkMap.PeerConfig != nil {
dnsRouteFeatureFlag = networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled
}
routedDomains, routes := toRoutes(networkMap.GetRoutes()) routedDomains, routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes); err != nil { if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
// todo: useRoutingPeerDnsResolutionEnabled from network map proto e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains)
e.updateDNSForwarder(true, routedDomains)
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))

View File

@@ -21,7 +21,11 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const useNewDNSRoute = true const (
handlerTypeDynamic = iota
handlerTypeDomain
handlerTypeStatic
)
type routerPeerStatus struct { type routerPeerStatus struct {
connected bool connected bool
@@ -67,6 +71,7 @@ func newClientNetworkWatcher(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerStore *peerstore.Store, peerStore *peerstore.Store,
useNewDNSRoute bool,
) *clientNetwork { ) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@@ -88,6 +93,7 @@ func newClientNetworkWatcher(
wgInterface, wgInterface,
dnsServer, dnsServer,
peerStore, peerStore,
useNewDNSRoute,
), ),
} }
return client return client
@@ -400,18 +406,19 @@ func handlerFromRoute(
wgInterface iface.IWGIface, wgInterface iface.IWGIface,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerStore *peerstore.Store, peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler { ) RouteHandler {
if rt.IsDynamic() { switch handlerType(rt, useNewDNSRoute) {
if useNewDNSRoute { case handlerTypeDomain:
return dnsinterceptor.New( return dnsinterceptor.New(
rt, rt,
routeRefCounter, routeRefCounter,
allowedIPsRefCounter, allowedIPsRefCounter,
statusRecorder, statusRecorder,
dnsServer, dnsServer,
peerStore, peerStore,
) )
} case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute( return dynamic.NewRoute(
rt, rt,
@@ -422,6 +429,18 @@ func handlerFromRoute(
wgInterface, wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
) )
default:
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
} }
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) }
func handlerType(rt *route.Route, useNewDNSRoute bool) int {
if !rt.IsDynamic() {
return handlerTypeStatic
}
if useNewDNSRoute {
return handlerTypeDomain
}
return handlerTypeStatic
} }

View File

@@ -36,7 +36,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap GetClientRoutes() route.HAMap
@@ -66,9 +66,10 @@ type DefaultManager struct {
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
stateManager *statemanager.Manager stateManager *statemanager.Manager
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap clientRoutes route.HAMap
dnsServer dns.Server dnsServer dns.Server
peerStore *peerstore.Store peerStore *peerstore.Store
useNewDNSRoute bool
} }
func NewManager( func NewManager(
@@ -227,7 +228,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
} }
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
log.Infof("not updating routes as context is closed") log.Infof("not updating routes as context is closed")
@@ -237,6 +238,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
m.useNewDNSRoute = useNewDNSRoute
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
@@ -318,6 +320,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.allowedIPsRefCounter, m.allowedIPsRefCounter,
m.dnsServer, m.dnsServer,
m.peerStore, m.peerStore,
m.useNewDNSRoute,
) )
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
@@ -347,7 +350,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore) clientNetworkWatcher = newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }

View File

@@ -436,11 +436,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
} }
if len(testCase.inputInitRoutes) > 0 { if len(testCase.inputInitRoutes) > 0 {
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) _ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
require.NoError(t, err, "should update routes with init routes") require.NoError(t, err, "should update routes with init routes")
} }
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) _ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
require.NoError(t, err, "should update routes") require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected expectedWatchers := testCase.clientNetworkWatchersExpected

View File

@@ -32,7 +32,7 @@ func (m *MockManager) InitialRouteRange() []string {
} }
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface // UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error { func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error {
if m.UpdateRoutesFunc != nil { if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes) return m.UpdateRoutesFunc(updateSerial, newRoutes)
} }