diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index b4110f342..170eaee08 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -71,3 +71,5 @@ func buildRecordKey(name string, class, qType uint16) string { key := fmt.Sprintf("%s_%d_%d", name, class, qType) return key } + +func (d *localResolver) probeAvailability() {} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index ed4116b9d..0739f0542 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -48,3 +48,7 @@ func (m *MockServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { func (m *MockServer) SearchDomains() []string { return make([]string, 0) } + +// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface +func (m *MockServer) ProbeAvailability() { +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 439c27a27..c22672cd0 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -32,6 +32,7 @@ type Server interface { UpdateDNSServer(serial uint64, update nbdns.Config) error OnUpdatedHostDNSServer(strings []string) SearchDomains() []string + ProbeAvailability() } type registeredHandlerMap map[string]handlerWithStop @@ -63,6 +64,7 @@ type DefaultServer struct { type handlerWithStop interface { dns.Handler stop() + probeAvailability() } type muxUpdate struct { @@ -248,6 +250,14 @@ func (s *DefaultServer) SearchDomains() []string { return searchDomains } +// ProbeAvailability tests each upstream group's servers for availability +// and deactivates the group if no server responds +func (s *DefaultServer) ProbeAvailability() { + for _, mux := range s.dnsMuxMap { + mux.probeAvailability() + } +} + func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // is the service should be Disabled, we stop the listener or fake resolver // and proceed with a regular update to clean up the 
handlers and records @@ -378,6 +388,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam }) } } + return muxUpdates, nil } @@ -488,13 +499,13 @@ func (s *DefaultServer) upstreamCallbacks( } l := log.WithField("nameservers", nsGroup.NameServers) - l.Debug("reactivate temporary Disabled nameserver group") + l.Debug("reactivate temporary disabled nameserver group") if nsGroup.Primary { s.currentConfig.RouteAll = true } if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { - l.WithError(err).Error("reactivate temporary Disabled nameserver group, DNS update apply") + l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") } } return diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index a716e0f24..b6a0f437a 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -19,10 +19,14 @@ const ( failsTillDeact = int32(5) reactivatePeriod = 30 * time.Second upstreamTimeout = 15 * time.Second + probeTimeout = 2 * time.Second ) +const testRecord = "." + type upstreamClient interface { exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) + exchangeContext(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) } type UpstreamResolver interface { @@ -80,7 +84,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { rm, t, err := u.upstreamClient.exchange(upstream, r) if err != nil { - if err == context.DeadlineExceeded || isTimeout(err) { + if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { log.WithError(err).WithField("upstream", upstream). 
Warn("got an error while connecting to upstream") continue @@ -134,13 +138,49 @@ func (u *upstreamResolverBase) checkUpstreamFails() { case <-u.ctx.Done(): return default: - // todo test the deactivation logic, it seems to affect the client - if runtime.GOOS != "ios" { - log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod) - u.deactivate() - u.disabled = true - go u.waitUntilResponse() - } + } + + u.disable() +} + +// probeAvailability tests all upstream servers simultaneously and +// disables the resolver if none work +func (u *upstreamResolverBase) probeAvailability() { + u.mutex.Lock() + defer u.mutex.Unlock() + + select { + case <-u.ctx.Done(): + return + default: + } + + var success bool + var mu sync.Mutex + var wg sync.WaitGroup + + for _, upstream := range u.upstreamServers { + upstream := upstream + + wg.Add(1) + go func() { + defer wg.Done() + if err := u.testNameserver(upstream); err != nil { + log.Warnf("probing upstream nameserver %s: %s", upstream, err) + return + } + + mu.Lock() + defer mu.Unlock() + success = true + }() + } + + wg.Wait() + + // didn't find a working upstream server, let's disable and try later + if !success { + u.disable() } } @@ -156,8 +196,6 @@ func (u *upstreamResolverBase) waitUntilResponse() { Clock: backoff.SystemClock, } - r := new(dns.Msg).SetQuestion("netbird.io.", dns.TypeA) - operation := func() error { select { case <-u.ctx.Done(): @@ -165,16 +203,16 @@ func (u *upstreamResolverBase) waitUntilResponse() { default: } - var err error for _, upstream := range u.upstreamServers { - _, _, err = u.upstreamClient.exchange(upstream, r) - - if err == nil { + if err := u.testNameserver(upstream); err != nil { + log.Tracef("upstream check for %s: %s", upstream, err) + } else { + // at least one upstream server is available, stop probing return nil } } - log.Tracef("checking connectivity with upstreams %s failed with error: %s. 
Retrying in %s", err, u.upstreamServers, exponentialBackOff.NextBackOff()) + log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServers, exponentialBackOff.NextBackOff()) return fmt.Errorf("got an error from upstream check call") } @@ -200,3 +238,27 @@ func isTimeout(err error) bool { } return false } + +func (u *upstreamResolverBase) disable() { + if u.disabled { + return + } + + // todo test the deactivation logic, it seems to affect the client + if runtime.GOOS != "ios" { + log.Warnf("upstream resolving is Disabled for %v", reactivatePeriod) + u.deactivate() + u.disabled = true + go u.waitUntilResponse() + } +} + +func (u *upstreamResolverBase) testNameserver(server string) error { + ctx, cancel := context.WithTimeout(u.ctx, probeTimeout) + defer cancel() + + r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) + + _, _, err := u.upstreamClient.exchangeContext(ctx, server, r) + return err +} diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index 7283efa20..c49d04749 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -41,6 +41,10 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net } func (u *upstreamResolverIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + return u.exchangeContext(context.Background(), upstream, r) +} + +func (u *upstreamResolverIOS) exchangeContext(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { client := &dns.Client{} upstreamHost, _, err := net.SplitHostPort(upstream) if err != nil { @@ -52,7 +56,7 @@ func (u *upstreamResolverIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg client = u.getClientPrivate() } - return client.Exchange(r, upstream) + return client.ExchangeContext(ctx, r, upstream) } // getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface diff --git 
a/client/internal/dns/upstream_nonios.go b/client/internal/dns/upstream_nonios.go index a146f3f98..05ee727c5 100644 --- a/client/internal/dns/upstream_nonios.go +++ b/client/internal/dns/upstream_nonios.go @@ -24,9 +24,13 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, ip net } func (u *upstreamResolverNonIOS) exchange(upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { - upstreamExchangeClient := &dns.Client{} + // default upstream timeout ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) - rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream) - cancel() - return rm, t, err + defer cancel() + return u.exchangeContext(ctx, upstream, r) +} + +func (u *upstreamResolverNonIOS) exchangeContext(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + upstreamExchangeClient := &dns.Client{} + return upstreamExchangeClient.ExchangeContext(ctx, r, upstream) } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index d73e04ce0..7f200e7f7 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -105,8 +105,13 @@ type mockUpstreamResolver struct { err error } -// ExchangeContext mock implementation of ExchangeContext from upstreamResolver +// Exchange mock implementation of Exchange from upstreamResolver func (c mockUpstreamResolver) exchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) { + return c.exchangeContext(context.Background(), upstream, r) +} + +// ExchangeContext mock implementation of ExchangeContext from upstreamResolver +func (c mockUpstreamResolver) exchangeContext(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) { return c.r, c.rtt, c.err } diff --git a/client/internal/engine.go b/client/internal/engine.go index bbce1dced..4493c75e8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -682,6 +682,10 @@ func (e 
*Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } + // Test received (upstream) servers for availability right away instead of upon usage. + // If no server of a server group responds this will disable the respective handler and retry later. + e.dnsServer.ProbeAvailability() + if e.acl != nil { e.acl.ApplyFiltering(networkMap) }