Merge branch 'main' into android-dns-routes

This commit is contained in:
Viktor Liu
2025-06-17 15:08:31 +02:00
12 changed files with 887 additions and 142 deletions

View File

@@ -213,15 +213,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 {
return
}
log.Tracef("received DNS request for domain=%s type=%v class=%v",
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query")
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
@@ -230,19 +233,19 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock()
if peerKey == "" {
d.writeDNSError(w, r, "no current peer key")
d.writeDNSError(w, r, logger, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("create DNS client: %v", err))
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
@@ -253,9 +256,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
logger.Errorf("failed writing DNS response: %v", err)
}
return
}
@@ -265,34 +268,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err)
logger.Errorf("failed writing DNS response: %v", err)
}
}
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err)
logger.Errorf("failed to write DNS error response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err)
logger.Errorf("failed writing DNS continue response: %v", err)
}
}