[client] Tighten allowed domains for dns forwarder (#3978)

This commit is contained in:
Viktor Liu
2025-06-17 14:03:00 +02:00
committed by GitHub
parent 75c1be69cf
commit de7384e8ea
4 changed files with 697 additions and 66 deletions

View File

@@ -144,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 {
return
}
log.Tracef("received DNS request for domain=%s type=%v class=%v",
logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query")
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
@@ -161,13 +164,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock()
if peerKey == "" {
d.writeDNSError(w, r, "no current peer key")
d.writeDNSError(w, r, logger, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err))
d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return
}
@@ -184,9 +187,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
logger.Errorf("failed writing DNS response: %v", err)
}
return
}
@@ -196,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err)
logger.Errorf("failed writing DNS response: %v", err)
}
}
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err)
logger.Errorf("failed to write DNS error response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err)
logger.Errorf("failed writing DNS continue response: %v", err)
}
}