diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index d94bbe592..0962e8a3a 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -39,6 +39,7 @@ type upstreamResolver struct { reactivatePeriod time.Duration upstreamTimeout time.Duration lIP net.IP + lNet *net.IPNet lName string iIndex int @@ -46,18 +47,6 @@ type upstreamResolver struct { reactivate func() } -// func newUpstreamResolver(parentCTX context.Context) *upstreamResolver { -// ctx, cancel := context.WithCancel(parentCTX) -// return &upstreamResolver{ -// ctx: ctx, -// cancel: cancel, -// upstreamClient: &dns.Client{}, -// upstreamTimeout: upstreamTimeout, -// reactivatePeriod: reactivatePeriod, -// failsTillDeact: failsTillDeact, -// } -// } - func getInterfaceIndex(interfaceName string) (int, error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { @@ -72,7 +61,7 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr ctx, cancel := context.WithCancel(parentCTX) // Specify the local IP address you want to bind to - localIP, _, err := net.ParseCIDR(wgAddr) // Should be our interface IP + localIP, localNet, err := net.ParseCIDR(wgAddr) // Should be our interface IP if err != nil { log.Errorf("error while parsing CIDR: %s", err) } @@ -90,12 +79,15 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr reactivatePeriod: reactivatePeriod, failsTillDeact: failsTillDeact, lIP: localIP, + lNet: localNet, iIndex: localIFaceIndex, lName: interfaceName, } } -func (u *upstreamResolver) getClient() *dns.Client { +// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface +// This method is needed for iOS +func (u *upstreamResolver) getClientPrivate() *dns.Client { dialer := &net.Dialer{ LocalAddr: &net.UDPAddr{ IP: u.lIP, @@ -134,7 +126,7 @@ func (u *upstreamResolver) stop() { func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { defer u.checkUpstreamFails() - //log.WithField("question", r.Question[0]).Debug("received an upstream question") + log.WithField("question", r.Question[0]).Trace("received an upstream question") select { case <-u.ctx.Done(): @@ -148,13 +140,23 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { t time.Duration rm *dns.Msg ) - upstreamExchangeClient := u.getClient() + + upstreamExchangeClient := &dns.Client{} if runtime.GOOS != "ios" { ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout) rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream) cancel() } else { log.Debugf("ios upstream resolver: %s", upstream) + upstreamHost, _, err := net.SplitHostPort(upstream) + if err != nil { + log.Errorf("error while parsing upstream host: %s", err) + } + upstreamIP := net.ParseIP(upstreamHost) + if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { + log.Debugf("using private client to query upstream: %s", upstream) + upstreamExchangeClient = u.getClientPrivate() + } rm, t, err = upstreamExchangeClient.Exchange(r, upstream) } @@ -170,7 +172,12 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - log.Debugf("took %s to query the upstream %s", t, upstream) + if !rm.Response { + log.WithError(err).WithField("upstream", upstream). + Warn("no response from upstream") + } + + log.Tracef("took %s to query the upstream %s", t, upstream) err = w.WriteMsg(rm) if err != nil { @@ -201,11 +208,11 @@ func (u *upstreamResolver) checkUpstreamFails() { case <-u.ctx.Done(): return default: - //todo test the deactivation logic, it seems to affect the client - //log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) - //u.deactivate() - //u.disabled = true - //go u.waitUntilResponse() + // todo test the deactivation logic, it seems to affect the client + // log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) + // u.deactivate() + // u.disabled = true + // go u.waitUntilResponse() } }