diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 1230a4e46..5c7cb31fc 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error { return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg { +func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) { if len(query.Question) == 0 { - return nil + return } question := query.Question[0] - logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s", - question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) + qname := strings.ToLower(question.Name) - domain := strings.ToLower(question.Name) + logger.Tracef("question: domain=%s type=%s class=%s", + qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) resp := query.SetReply(query) network := resutil.NetworkForQtype(question.Qtype) if network == "" { resp.Rcode = dns.RcodeNotImplemented - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) - // query doesn't match any configured domain + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, ".")) if mostSpecificResId == "" { resp.Rcode = dns.RcodeRefused - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype) + result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype) if result.Err != nil { - f.handleDNSError(ctx, logger, w, question, resp, domain, result) - return nil + f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime) + return } f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries) - resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...) - f.cache.set(domain, question.Qtype, result.IPs) + resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...) + f.cache.set(qname, question.Qtype, result.IPs) - return resp + f.writeResponse(logger, w, resp, qname, startTime) +} + +func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) { + if err := w.WriteMsg(resp); err != nil { + logger.Errorf("failed to write DNS response: %v", err) + return + } + + logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) +} + +// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. +type udpResponseWriter struct { + dns.ResponseWriter + query *dns.Msg +} + +func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { + opt := u.query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + maxSize = int(opt.UDPSize()) + } + + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } + + return u.ResponseWriter.WriteMsg(resp) } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { @@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - opt := query.IsEdns0() - maxSize := dns.MinMsgSize - if opt != nil { - // client advertised a larger EDNS0 buffer - maxSize = int(opt.UDPSize()) - } - - // if our response is too big, truncate and set the TC bit - if resp.Len() > maxSize { - resp.Truncate(maxSize) - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { @@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, w, query, startTime) } func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { @@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError( resp *dns.Msg, domain string, result resutil.LookupResult, + startTime time.Time, ) { qType := question.Qtype qTypeName := dns.TypeToString[qType] @@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError( // NotFound: cache negative result and respond if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess { f.cache.set(domain, question.Qtype, nil) - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError( logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...) resp.Rcode = dns.RcodeSuccess - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write cached DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError( verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType) if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess { resp.Rcode = verifyResult.Rcode - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } } @@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError( // No cache or verification failed. Log with or without the server field for more context. var dnsErr *net.DNSError if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" { - logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) + logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) } else { logger.Warnf(errResolveFailed, domain, result.Err) } - // Write final failure response. - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) } // getMatchingEntries retrieves the resource IDs for a given domain. diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 6416c2f21..7325ef8a7 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") @@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockFirewall.AssertExpectations(t) mockResolver.AssertExpectations(t) } else { - if resp != nil { - assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, - "Unauthorized domain should not return successful answers") - } + require.NotNil(t, resp, "Expected response") + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") mockFirewall.AssertNotCalled(t, "UpdateSet") mockResolver.AssertNotCalled(t, "LookupNetIP") } @@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now()) // Verify response + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.NotEmpty(t, resp.Answer) - } else if resp != nil { + } else { + require.NotNil(t, resp, "Expected response") assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, "Unauthorized domain should be refused or have no answers") } @@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { query.SetQuestion("example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Verify response contains all IPs + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.Len(t, resp.Answer, 3, "Should have 3 answer records") @@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { }, } - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Check the response written to the writer require.NotNil(t, writtenResp, "Expected response to be written") @@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Second query: serve from cache after upstream failure q2 := &dns.Msg{} q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp, "expected response to be written") - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(mixedQuery+".", dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q2 := &dns.Msg{} q2.SetQuestion("EXAMPLE.COM", dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp) - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2) + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { query.SetQuestion("smtp.mail.example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) assert.Equal(t, dns.RcodeSuccess, resp.Rcode) @@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { query := &dns.Msg{} query.SetQuestion(dns.Fqdn("example.com"), tt.queryType) - var writtenResp *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writtenResp = m - return nil - }, - } + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) - - // If a response was returned, it means it should be written (happens in wrapper functions) - if resp != nil && writtenResp == nil { - writtenResp = resp - } - - require.NotNil(t, writtenResp, "Expected response to be written") - assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + resp := mockWriter.GetLastResponse() + require.NotNil(t, resp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description) if tt.expectNoAnswer { - assert.Empty(t, writtenResp.Answer, "Response should have no answer records") + assert.Empty(t, resp.Answer, "Response should have no answer records") } mockResolver.AssertExpectations(t) @@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) { query := &dns.Msg{} // Don't set any question - writeCalled := false - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writeCalled = true - return nil - }, - } - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - assert.Nil(t, resp, "Should return nil for empty query") - assert.False(t, writeCalled, "Should not write response for empty query") + assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query") }