switching between client to query upstream

This commit is contained in:
Pascal Fischer
2023-11-06 12:32:25 +01:00
parent e193df3bc7
commit 5632d222cc

View File

@@ -39,6 +39,7 @@ type upstreamResolver struct {
reactivatePeriod time.Duration
upstreamTimeout time.Duration
lIP net.IP
lNet *net.IPNet
lName string
iIndex int
@@ -46,18 +47,6 @@ type upstreamResolver struct {
reactivate func()
}
// func newUpstreamResolver(parentCTX context.Context) *upstreamResolver {
// ctx, cancel := context.WithCancel(parentCTX)
// return &upstreamResolver{
// ctx: ctx,
// cancel: cancel,
// upstreamClient: &dns.Client{},
// upstreamTimeout: upstreamTimeout,
// reactivatePeriod: reactivatePeriod,
// failsTillDeact: failsTillDeact,
// }
// }
func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
@@ -72,7 +61,7 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr
ctx, cancel := context.WithCancel(parentCTX)
// Specify the local IP address you want to bind to
localIP, _, err := net.ParseCIDR(wgAddr) // Should be our interface IP
localIP, localNet, err := net.ParseCIDR(wgAddr) // Should be our interface IP
if err != nil {
log.Errorf("error while parsing CIDR: %s", err)
}
@@ -90,12 +79,15 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
lIP: localIP,
lNet: localNet,
iIndex: localIFaceIndex,
lName: interfaceName,
}
}
func (u *upstreamResolver) getClient() *dns.Client {
// getClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS
func (u *upstreamResolver) getClientPrivate() *dns.Client {
dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{
IP: u.lIP,
@@ -134,7 +126,7 @@ func (u *upstreamResolver) stop() {
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails()
//log.WithField("question", r.Question[0]).Debug("received an upstream question")
log.WithField("question", r.Question[0]).Trace("received an upstream question")
select {
case <-u.ctx.Done():
@@ -148,13 +140,23 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
t time.Duration
rm *dns.Msg
)
upstreamExchangeClient := u.getClient()
upstreamExchangeClient := &dns.Client{}
if runtime.GOOS != "ios" {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
cancel()
} else {
log.Debugf("ios upstream resolver: %s", upstream)
upstreamHost, _, err := net.SplitHostPort(upstream)
if err != nil {
log.Errorf("error while parsing upstream host: %s", err)
}
upstreamIP := net.ParseIP(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) {
log.Debugf("using private client to query upstream: %s", upstream)
upstreamExchangeClient = u.getClientPrivate()
}
rm, t, err = upstreamExchangeClient.Exchange(r, upstream)
}
@@ -170,7 +172,12 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
log.Debugf("took %s to query the upstream %s", t, upstream)
if !rm.Response {
log.WithError(err).WithField("upstream", upstream).
Warn("no response from upstream")
}
log.Tracef("took %s to query the upstream %s", t, upstream)
err = w.WriteMsg(rm)
if err != nil {
@@ -201,11 +208,11 @@ func (u *upstreamResolver) checkUpstreamFails() {
case <-u.ctx.Done():
return
default:
//todo test the deactivation logic, it seems to affect the client
//log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
//u.deactivate()
//u.disabled = true
//go u.waitUntilResponse()
// todo test the deactivation logic, it seems to affect the client
// log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
// u.deactivate()
// u.disabled = true
// go u.waitUntilResponse()
}
}