use dns.Client.Exchange

This commit is contained in:
Maycon Santos
2023-11-03 20:35:52 +01:00
parent 64084ca130
commit 65052e5cba

View File

@@ -4,16 +4,14 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"net" "net"
"net/netip" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall" "syscall"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/libp2p/go-netroute"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -40,6 +38,9 @@ type upstreamResolver struct {
mutex sync.Mutex mutex sync.Mutex
reactivatePeriod time.Duration reactivatePeriod time.Duration
upstreamTimeout time.Duration upstreamTimeout time.Duration
lIP net.IP
lName string
iIndex int
deactivate func() deactivate func()
reactivate func() reactivate func()
@@ -60,6 +61,7 @@ type upstreamResolver struct {
func getInterfaceIndex(interfaceName string) (int, error) { func getInterfaceIndex(interfaceName string) (int, error) {
iface, err := net.InterfaceByName(interfaceName) iface, err := net.InterfaceByName(interfaceName)
if err != nil { if err != nil {
log.Errorf("unable to get interface by name error: %s", err)
return 0, err return 0, err
} }
@@ -75,54 +77,52 @@ func newUpstreamResolver(parentCTX context.Context, interfaceName string, wgAddr
log.Errorf("error while parsing CIDR: %s", err) log.Errorf("error while parsing CIDR: %s", err)
} }
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(interfaceName)
rand.Seed(time.Now().UnixNano()) log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s", interfaceName, index, localIP)
port := rand.Intn(4001) + 1000
log.Debugf("UpstreamResolver interface name: %s, index: %d, ip: %s, port: %d", interfaceName, index, localIP, port)
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
} }
localIFaceIndex := index // Should be our interface index localIFaceIndex := index // Should be our interface index
// Create a custom dialer with the LocalAddr set to the desired IP
return &upstreamResolver{
ctx: ctx,
cancel: cancel,
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
lIP: localIP,
iIndex: localIFaceIndex,
lName: interfaceName,
}
}
func (u *upstreamResolver) getClient() *dns.Client {
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: localIP, IP: u.lIP,
Port: port, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: upstreamTimeout,
Control: func(network, address string, c syscall.RawConn) error { Control: func(network, address string, c syscall.RawConn) error {
var operr error var operr error
fn := func(s uintptr) { fn := func(s uintptr) {
operr = syscall.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, localIFaceIndex) operr = unix.SetsockoptInt(int(s), unix.IPPROTO_IP, unix.IP_BOUND_IF, u.iIndex)
} }
if err := c.Control(fn); err != nil { if err := c.Control(fn); err != nil {
return err return err
} }
if operr != nil {
log.Errorf("error while setting socket option: %s", operr)
}
return operr return operr
}, },
} }
// pktConn, err := dialer.Dial("udp", "100.127.136.151:10053")
// if err != nil {
// log.Errorf("error while dialing: %s", err)
//
// } else {
// pktConn.Write([]byte("hello"))
// pktConn.Close()
// }
// Create a new DNS client with the custom dialer
client := &dns.Client{ client := &dns.Client{
Dialer: dialer, Dialer: dialer,
} }
return client
return &upstreamResolver{
ctx: ctx,
cancel: cancel,
upstreamClient: client,
upstreamTimeout: upstreamTimeout,
reactivatePeriod: reactivatePeriod,
failsTillDeact: failsTillDeact,
}
} }
func (u *upstreamResolver) stop() { func (u *upstreamResolver) stop() {
@@ -134,7 +134,7 @@ func (u *upstreamResolver) stop() {
func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
defer u.checkUpstreamFails() defer u.checkUpstreamFails()
log.WithField("question", r.Question[0]).Debug("received an upstream question") //log.WithField("question", r.Question[0]).Debug("received an upstream question")
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
@@ -143,23 +143,20 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
for _, upstream := range u.upstreamServers { for _, upstream := range u.upstreamServers {
log.Debugf("querying the upstream %s", upstream) var (
rr, errR := netroute.New() err error
if errR != nil { t time.Duration
log.Errorf("unable to create networute: %s", errR) rm *dns.Msg
)
upstreamExchangeClient := u.getClient()
if runtime.GOOS != "ios" {
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err = upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
cancel()
} else { } else {
add := netip.MustParseAddrPort(upstream) log.Debugf("ios upstream resolver: %s", upstream)
_, gateway, preferredSrc, errR := rr.Route(add.Addr().AsSlice()) rm, t, err = upstreamExchangeClient.Exchange(r, upstream)
if errR != nil {
log.Errorf("getting routes returned an error: %v", errR)
} else {
log.Infof("upstream %s gateway: %s, preferredSrc: %s", add.Addr(), gateway, preferredSrc)
}
} }
ctx, cancel := context.WithTimeout(u.ctx, u.upstreamTimeout)
rm, t, err := u.upstreamClient.ExchangeContext(ctx, r, upstream)
cancel()
if err != nil { if err != nil {
if err == context.DeadlineExceeded || isTimeout(err) { if err == context.DeadlineExceeded || isTimeout(err) {
@@ -169,7 +166,7 @@ func (u *upstreamResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.WithError(err).WithField("upstream", upstream). log.WithError(err).WithField("upstream", upstream).
Error("got an error while querying the upstream") Error("got other error while querying the upstream")
return return
} }
@@ -204,10 +201,11 @@ func (u *upstreamResolver) checkUpstreamFails() {
case <-u.ctx.Done(): case <-u.ctx.Done():
return return
default: default:
log.Warnf("upstream resolving is disabled for %v", reactivatePeriod) //todo test the deactivation logic, it seems to affect the client
u.deactivate() //log.Warnf("upstream resolving is disabled for %v", reactivatePeriod)
u.disabled = true //u.deactivate()
go u.waitUntilResponse() //u.disabled = true
//go u.waitUntilResponse()
} }
} }