mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Add disable system flags (#3153)
This commit is contained in:
@@ -61,6 +61,11 @@ type ConfigInput struct {
|
||||
DNSRouteInterval *time.Duration
|
||||
ClientCertPath string
|
||||
ClientCertKeyPath string
|
||||
|
||||
DisableClientRoutes *bool
|
||||
DisableServerRoutes *bool
|
||||
DisableDNS *bool
|
||||
DisableFirewall *bool
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
@@ -78,6 +83,12 @@ type Config struct {
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed *bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
|
||||
// SSHKey is a private SSH key in a PEM format
|
||||
SSHKey string
|
||||
|
||||
@@ -402,7 +413,46 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
config.DNSRouteInterval = dynamic.DefaultInterval
|
||||
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes {
|
||||
if *input.DisableClientRoutes {
|
||||
log.Infof("disabling client routes")
|
||||
} else {
|
||||
log.Infof("enabling client routes")
|
||||
}
|
||||
config.DisableClientRoutes = *input.DisableClientRoutes
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes {
|
||||
if *input.DisableServerRoutes {
|
||||
log.Infof("disabling server routes")
|
||||
} else {
|
||||
log.Infof("enabling server routes")
|
||||
}
|
||||
config.DisableServerRoutes = *input.DisableServerRoutes
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS {
|
||||
if *input.DisableDNS {
|
||||
log.Infof("disabling DNS configuration")
|
||||
} else {
|
||||
log.Infof("enabling DNS configuration")
|
||||
}
|
||||
config.DisableDNS = *input.DisableDNS
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall {
|
||||
if *input.DisableFirewall {
|
||||
log.Infof("disabling firewall configuration")
|
||||
} else {
|
||||
log.Infof("enabling firewall configuration")
|
||||
}
|
||||
config.DisableFirewall = *input.DisableFirewall
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.ClientCertKeyPath != "" {
|
||||
|
||||
@@ -415,6 +415,11 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
||||
RosenpassPermissive: config.RosenpassPermissive,
|
||||
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
|
||||
DNSRouteInterval: config.DNSRouteInterval,
|
||||
|
||||
DisableClientRoutes: config.DisableClientRoutes,
|
||||
DisableServerRoutes: config.DisableServerRoutes,
|
||||
DisableDNS: config.DisableDNS,
|
||||
DisableFirewall: config.DisableFirewall,
|
||||
}
|
||||
|
||||
if config.PreSharedKey != "" {
|
||||
|
||||
@@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
type noopHostConfigurator struct{}
|
||||
|
||||
func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) restoreHostDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) supportCustomPort() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type registeredHandlerMap map[string]handlerWithStop
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
disableSys bool
|
||||
mux sync.Mutex
|
||||
service service
|
||||
dnsMuxMap registeredHandlerMap
|
||||
@@ -84,7 +85,14 @@ type muxUpdate struct {
|
||||
}
|
||||
|
||||
// NewDefaultServer returns a new dns server
|
||||
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
|
||||
func NewDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
customAddress string,
|
||||
statusRecorder *peer.Status,
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) (*DefaultServer, error) {
|
||||
var addrPort *netip.AddrPort
|
||||
if customAddress != "" {
|
||||
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
|
||||
@@ -101,7 +109,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
|
||||
dnsService = newServiceViaListener(wgInterface, addrPort)
|
||||
}
|
||||
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
|
||||
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
|
||||
}
|
||||
|
||||
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
|
||||
@@ -112,9 +120,10 @@ func NewDefaultServerPermanentUpstream(
|
||||
config nbdns.Config,
|
||||
listener listener.NetworkChangeListener,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
log.Debugf("host dns address list is: %v", hostsDnsList)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.hostsDNSHolder.set(hostsDnsList)
|
||||
ds.permanent = true
|
||||
ds.addHostRootZone()
|
||||
@@ -131,17 +140,26 @@ func NewDefaultServerIos(
|
||||
wgInterface WGIface,
|
||||
iosDnsManager IosDnsManager,
|
||||
statusRecorder *peer.Status,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
|
||||
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
|
||||
ds.iosDnsManager = iosDnsManager
|
||||
return ds
|
||||
}
|
||||
|
||||
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
|
||||
func newDefaultServer(
|
||||
ctx context.Context,
|
||||
wgInterface WGIface,
|
||||
dnsService service,
|
||||
statusRecorder *peer.Status,
|
||||
stateManager *statemanager.Manager,
|
||||
disableSys bool,
|
||||
) *DefaultServer {
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
defaultServer := &DefaultServer{
|
||||
ctx: ctx,
|
||||
ctxCancel: stop,
|
||||
disableSys: disableSys,
|
||||
service: dnsService,
|
||||
handlerChain: NewHandlerChain(),
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
@@ -220,6 +238,13 @@ func (s *DefaultServer) Initialize() (err error) {
|
||||
}
|
||||
|
||||
s.stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
if s.disableSys {
|
||||
log.Info("system DNS is disabled, not setting up host manager")
|
||||
s.hostManager = &noopHostConfigurator{}
|
||||
return nil
|
||||
}
|
||||
|
||||
s.hostManager, err = s.initialize()
|
||||
if err != nil {
|
||||
return fmt.Errorf("initialize: %w", err)
|
||||
@@ -268,47 +293,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
|
||||
|
||||
// UpdateDNSServer processes an update received from the management service
|
||||
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
if s.ctx.Err() != nil {
|
||||
log.Infof("not updating DNS server as context is closed")
|
||||
return s.ctx.Err()
|
||||
default:
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
}
|
||||
|
||||
if s.hostManager == nil {
|
||||
return fmt.Errorf("dns service is not initialized yet")
|
||||
}
|
||||
if serial < s.updateSerial {
|
||||
return fmt.Errorf("not applying dns update, error: "+
|
||||
"network update is %d behind the last applied update", s.updateSerial-serial)
|
||||
}
|
||||
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
return nil
|
||||
}
|
||||
if s.hostManager == nil {
|
||||
return fmt.Errorf("dns service is not initialized yet")
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return fmt.Errorf("apply configuration: %w", err)
|
||||
}
|
||||
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
|
||||
ZeroNil: true,
|
||||
IgnoreZeroValue: true,
|
||||
SlicesAsSets: true,
|
||||
UseStringer: true,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
|
||||
}
|
||||
|
||||
if s.previousConfigHash == hash {
|
||||
log.Debugf("not applying the dns configuration update as there is nothing new")
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.applyConfiguration(update); err != nil {
|
||||
return fmt.Errorf("apply configuration: %w", err)
|
||||
}
|
||||
|
||||
s.updateSerial = serial
|
||||
s.previousConfigHash = hash
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DefaultServer) SearchDomains() []string {
|
||||
@@ -627,8 +652,11 @@ func (s *DefaultServer) upstreamCallbacks(
|
||||
s.currentConfig.RouteAll = true
|
||||
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
|
||||
}
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
|
||||
if s.hostManager != nil {
|
||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
||||
}
|
||||
}
|
||||
|
||||
s.updateNSState(nsGroup, nil, true)
|
||||
|
||||
@@ -294,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) {
|
||||
t.Log(err)
|
||||
}
|
||||
}()
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -403,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
|
||||
if err != nil {
|
||||
t.Errorf("create DNS server: %v", err)
|
||||
return
|
||||
@@ -498,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
|
||||
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
@@ -633,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
|
||||
var dnsList []string
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -657,7 +657,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
@@ -749,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
defer wgIFace.Close()
|
||||
dnsConfig := nbdns.Config{}
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
|
||||
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
|
||||
err = dnsServer.Initialize()
|
||||
if err != nil {
|
||||
t.Errorf("failed to initialize DNS server: %v", err)
|
||||
|
||||
@@ -83,6 +83,11 @@ func (h *Manager) allowDNSFirewall() error {
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
}
|
||||
|
||||
if h.firewall == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
|
||||
@@ -108,6 +108,11 @@ type EngineConfig struct {
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DNSRouteInterval time.Duration
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -374,18 +379,20 @@ func (e *Engine) Start() error {
|
||||
}
|
||||
e.dnsServer = dnsServer
|
||||
|
||||
e.routeManager = routemanager.NewManager(
|
||||
e.ctx,
|
||||
e.config.WgPrivateKey.PublicKey().String(),
|
||||
e.config.DNSRouteInterval,
|
||||
e.wgInterface,
|
||||
e.statusRecorder,
|
||||
e.relayManager,
|
||||
initialRoutes,
|
||||
e.stateManager,
|
||||
dnsServer,
|
||||
e.peerStore,
|
||||
)
|
||||
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
Context: e.ctx,
|
||||
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
|
||||
DNSRouteInterval: e.config.DNSRouteInterval,
|
||||
WGInterface: e.wgInterface,
|
||||
StatusRecorder: e.statusRecorder,
|
||||
RelayManager: e.relayManager,
|
||||
InitialRoutes: initialRoutes,
|
||||
StateManager: e.stateManager,
|
||||
DNSServer: dnsServer,
|
||||
PeerStore: e.peerStore,
|
||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||
})
|
||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to initialize route manager: %s", err)
|
||||
@@ -403,13 +410,8 @@ func (e *Engine) Start() error {
|
||||
return fmt.Errorf("create wg interface: %w", err)
|
||||
}
|
||||
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
} else if e.firewall != nil {
|
||||
if err := e.initFirewall(err); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := e.createFirewall(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.udpMux, err = e.wgInterface.Up()
|
||||
@@ -451,7 +453,27 @@ func (e *Engine) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall(error) error {
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := e.initFirewall(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall() error {
|
||||
if e.firewall.IsServerRouteSupported() {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
@@ -1348,6 +1370,7 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
if e.dnsServer != nil {
|
||||
return nil, e.dnsServer, nil
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
routes, dnsConfig, err := e.readInitialSettings()
|
||||
@@ -1361,14 +1384,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
|
||||
*dnsConfig,
|
||||
e.mobileDep.NetworkChangeListener,
|
||||
e.statusRecorder,
|
||||
e.config.DisableDNS,
|
||||
)
|
||||
go e.mobileDep.DnsReadyListener.OnReady()
|
||||
return routes, dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return nil, dnsServer, nil
|
||||
|
||||
default:
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
|
||||
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -253,7 +253,14 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
},
|
||||
}
|
||||
engine.wgInterface = wgIface
|
||||
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil)
|
||||
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
|
||||
Context: ctx,
|
||||
PublicKey: key.PublicKey().String(),
|
||||
DNSRouteInterval: time.Minute,
|
||||
WGInterface: engine.wgInterface,
|
||||
StatusRecorder: engine.statusRecorder,
|
||||
RelayManager: relayMgr,
|
||||
})
|
||||
_, _, err = engine.routeManager.Init()
|
||||
require.NoError(t, err)
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
|
||||
@@ -48,6 +48,21 @@ type Manager interface {
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
type ManagerConfig struct {
|
||||
Context context.Context
|
||||
PublicKey string
|
||||
DNSRouteInterval time.Duration
|
||||
WGInterface iface.IWGIface
|
||||
StatusRecorder *peer.Status
|
||||
RelayManager *relayClient.Manager
|
||||
InitialRoutes []*route.Route
|
||||
StateManager *statemanager.Manager
|
||||
DNSServer dns.Server
|
||||
PeerStore *peerstore.Store
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
}
|
||||
|
||||
// DefaultManager is the default instance of a route manager
|
||||
type DefaultManager struct {
|
||||
ctx context.Context
|
||||
@@ -55,7 +70,7 @@ type DefaultManager struct {
|
||||
mux sync.Mutex
|
||||
clientNetworks map[route.HAUniqueID]*clientNetwork
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter serverRouter
|
||||
serverRouter *serverRouter
|
||||
sysOps *systemops.SysOps
|
||||
statusRecorder *peer.Status
|
||||
relayMgr *relayClient.Manager
|
||||
@@ -67,48 +82,46 @@ type DefaultManager struct {
|
||||
dnsRouteInterval time.Duration
|
||||
stateManager *statemanager.Manager
|
||||
// clientRoutes is the most recent list of clientRoutes received from the Management Service
|
||||
clientRoutes route.HAMap
|
||||
dnsServer dns.Server
|
||||
peerStore *peerstore.Store
|
||||
useNewDNSRoute bool
|
||||
clientRoutes route.HAMap
|
||||
dnsServer dns.Server
|
||||
peerStore *peerstore.Store
|
||||
useNewDNSRoute bool
|
||||
disableClientRoutes bool
|
||||
disableServerRoutes bool
|
||||
}
|
||||
|
||||
func NewManager(
|
||||
ctx context.Context,
|
||||
pubKey string,
|
||||
dnsRouteInterval time.Duration,
|
||||
wgInterface iface.IWGIface,
|
||||
statusRecorder *peer.Status,
|
||||
relayMgr *relayClient.Manager,
|
||||
initialRoutes []*route.Route,
|
||||
stateManager *statemanager.Manager,
|
||||
dnsServer dns.Server,
|
||||
peerStore *peerstore.Store,
|
||||
) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(ctx)
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
mCTX, cancel := context.WithCancel(config.Context)
|
||||
notifier := notifier.NewNotifier()
|
||||
sysOps := systemops.NewSysOps(wgInterface, notifier)
|
||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||
|
||||
dm := &DefaultManager{
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: dnsRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
relayMgr: relayMgr,
|
||||
sysOps: sysOps,
|
||||
statusRecorder: statusRecorder,
|
||||
wgInterface: wgInterface,
|
||||
pubKey: pubKey,
|
||||
notifier: notifier,
|
||||
stateManager: stateManager,
|
||||
dnsServer: dnsServer,
|
||||
peerStore: peerStore,
|
||||
ctx: mCTX,
|
||||
stop: cancel,
|
||||
dnsRouteInterval: config.DNSRouteInterval,
|
||||
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
|
||||
relayMgr: config.RelayManager,
|
||||
sysOps: sysOps,
|
||||
statusRecorder: config.StatusRecorder,
|
||||
wgInterface: config.WGInterface,
|
||||
pubKey: config.PublicKey,
|
||||
notifier: notifier,
|
||||
stateManager: config.StateManager,
|
||||
dnsServer: config.DNSServer,
|
||||
peerStore: config.PeerStore,
|
||||
disableClientRoutes: config.DisableClientRoutes,
|
||||
disableServerRoutes: config.DisableServerRoutes,
|
||||
}
|
||||
|
||||
// don't proceed with client routes if it is disabled
|
||||
if config.DisableClientRoutes {
|
||||
return dm
|
||||
}
|
||||
|
||||
dm.setupRefCounters()
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
cr := dm.initialClientRoutes(initialRoutes)
|
||||
cr := dm.initialClientRoutes(config.InitialRoutes)
|
||||
dm.notifier.SetInitialClientRoutes(cr)
|
||||
}
|
||||
return dm
|
||||
@@ -156,7 +169,7 @@ func (m *DefaultManager) setupRefCounters() {
|
||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
||||
m.routeSelector = m.initSelector()
|
||||
|
||||
if nbnet.CustomRoutingDisabled() {
|
||||
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -202,6 +215,15 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||
}
|
||||
|
||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||
if m.disableServerRoutes {
|
||||
log.Info("server routes are disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if firewall == nil {
|
||||
return errors.New("firewall manager is not set")
|
||||
}
|
||||
|
||||
var err error
|
||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||
if err != nil {
|
||||
@@ -228,7 +250,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
}
|
||||
}
|
||||
|
||||
if !nbnet.CustomRoutingDisabled() {
|
||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
||||
log.Errorf("Error cleaning up routing: %v", err)
|
||||
} else {
|
||||
@@ -258,9 +280,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
||||
|
||||
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
|
||||
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
if !m.disableClientRoutes {
|
||||
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
|
||||
m.updateClientNetworks(updateSerial, filteredClientRoutes)
|
||||
m.notifier.OnNewRoutes(filteredClientRoutes)
|
||||
}
|
||||
|
||||
if m.serverRouter != nil {
|
||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||
|
||||
@@ -424,7 +424,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
|
||||
statusRecorder := peer.NewRecorder("https://mgm")
|
||||
ctx := context.TODO()
|
||||
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil)
|
||||
routeManager := NewManager(ManagerConfig{
|
||||
Context: ctx,
|
||||
PublicKey: localPeerKey,
|
||||
WGInterface: wgInterface,
|
||||
StatusRecorder: statusRecorder,
|
||||
})
|
||||
|
||||
_, _, err = routeManager.Init()
|
||||
|
||||
@@ -450,8 +455,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
||||
|
||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package routemanager
|
||||
|
||||
import "github.com/netbirdio/netbird/route"
|
||||
|
||||
type serverRouter interface {
|
||||
updateRoutes(map[route.ID]*route.Route) error
|
||||
removeFromServerNetwork(*route.Route) error
|
||||
cleanUp()
|
||||
}
|
||||
@@ -9,8 +9,19 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
|
||||
type serverRouter struct {
|
||||
}
|
||||
|
||||
func (r serverRouter) cleanUp() {
|
||||
}
|
||||
|
||||
func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
|
||||
return nil, fmt.Errorf("server route not supported on this os")
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type defaultServerRouter struct {
|
||||
type serverRouter struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
routes map[route.ID]*route.Route
|
||||
@@ -26,8 +26,8 @@ type defaultServerRouter struct {
|
||||
statusRecorder *peer.Status
|
||||
}
|
||||
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
|
||||
return &defaultServerRouter{
|
||||
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
|
||||
return &serverRouter{
|
||||
ctx: ctx,
|
||||
routes: make(map[route.ID]*route.Route),
|
||||
firewall: firewall,
|
||||
@@ -36,7 +36,7 @@ func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall f
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
||||
serverRoutesToRemove := make([]route.ID, 0)
|
||||
|
||||
for routeID := range m.routes {
|
||||
@@ -80,74 +80,72 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
|
||||
if m.ctx.Err() != nil {
|
||||
log.Infof("Not removing from server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveNatRule(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
}
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
delete(state.Routes, route.Network.String())
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.RemoveNatRule(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("remove routing rules: %w", err)
|
||||
}
|
||||
|
||||
delete(m.routes, route.ID)
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
delete(state.Routes, route.Network.String())
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
|
||||
if m.ctx.Err() != nil {
|
||||
log.Infof("Not adding to server network because context is done")
|
||||
return m.ctx.Err()
|
||||
default:
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.AddNatRule(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
|
||||
m.routes[route.ID] = route
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
routerPair, err := routeToRouterPair(route)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse prefix: %w", err)
|
||||
}
|
||||
|
||||
err = m.firewall.AddNatRule(routerPair)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert routing rules: %w", err)
|
||||
}
|
||||
|
||||
m.routes[route.ID] = route
|
||||
|
||||
state := m.statusRecorder.GetLocalPeerState()
|
||||
if state.Routes == nil {
|
||||
state.Routes = map[string]struct{}{}
|
||||
}
|
||||
|
||||
routeStr := route.Network.String()
|
||||
if route.IsDynamic() {
|
||||
routeStr = route.Domains.SafeString()
|
||||
}
|
||||
state.Routes[routeStr] = struct{}{}
|
||||
|
||||
m.statusRecorder.UpdateLocalPeerState(state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultServerRouter) cleanUp() {
|
||||
func (m *serverRouter) cleanUp() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
for _, r := range m.routes {
|
||||
|
||||
Reference in New Issue
Block a user