diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 693fbfe01..ca5eafa62 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -149,6 +149,7 @@ nfpms: dockers: - image_templates: - netbirdio/netbird:{{ .Version }}-amd64 + - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64 ids: - netbird goarch: amd64 @@ -164,6 +165,7 @@ dockers: - "--label=maintainer=dev@netbird.io" - image_templates: - netbirdio/netbird:{{ .Version }}-arm64v8 + - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8 ids: - netbird goarch: arm64 diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 33aba0bf5..3cd588630 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -745,9 +745,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool { return false } - // Step 1: Check connection tracking FIRST (with original addresses) if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { - // Step 2: Apply reverse DNAT for established connections translated := m.translateInboundReverse(packetData, d) if translated { // Re-decode after translation diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 22caaa761..7e7e7cc2d 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -11,9 +11,10 @@ import ( ) const ( - PriorityDNSRoute = 100 - PriorityMatchDomain = 50 - PriorityDefault = 1 + PriorityLocal = 100 + PriorityDNSRoute = 75 + PriorityUpstream = 50 + PriorityDefault = 1 ) type SubdomainMatcher interface { diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 5f03e0758..72c0004d5 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { // Setup handlers with different priorities chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) - chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) // Create test request @@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, - {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, }, queryDomain: "test.example.com.", @@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { priority int }{ {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, - {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "test.example.com.", priority: nbdns.PriorityUpstream}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, }, queryDomain: "sub.test.example.com.", @@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { // Add handlers in priority order chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) - chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) + chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) // Create test request @@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, {"remove", "example.com.", nbdns.PriorityDNSRoute}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: false, - nbdns.PriorityMatchDomain: true, + nbdns.PriorityDNSRoute: false, + nbdns.PriorityUpstream: true, }, }, { @@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, - {"remove", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, + {"remove", "example.com.", nbdns.PriorityUpstream}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: true, - nbdns.PriorityMatchDomain: false, + nbdns.PriorityDNSRoute: true, + nbdns.PriorityUpstream: false, }, }, { @@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) { priority int }{ {"add", "example.com.", nbdns.PriorityDNSRoute}, - {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityUpstream}, {"add", "example.com.", nbdns.PriorityDefault}, {"remove", "example.com.", nbdns.PriorityDNSRoute}, - {"remove", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityUpstream}, }, query: "example.com.", expectedCalls: map[int]bool{ - nbdns.PriorityDNSRoute: false, - nbdns.PriorityMatchDomain: false, - nbdns.PriorityDefault: true, + nbdns.PriorityDNSRoute: false, + nbdns.PriorityUpstream: false, + nbdns.PriorityDefault: true, }, }, } @@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Add handlers in mixed order chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) - chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream) // Test 1: Initial state w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} @@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { defaultHandler.Calls = nil // Test 3: Remove middle priority handler - chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + chain.RemoveHandler(testDomain, nbdns.PriorityUpstream) w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} // Now lowest priority handler (defaultHandler) should be called @@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) { shouldMatch bool }{ {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, - {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"example.com.", nbdns.PriorityUpstream, false, false}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, }, query: "example.com.", @@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", @@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, true}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", @@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, - {"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, true}, + {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false}, + {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "test.sub.example.com.", expectedMatch: "sub.example.com.", @@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, {"add", "example.com.", nbdns.PriorityDNSRoute, true}, }, query: "sub.example.com.", @@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) { priority int subdomain bool }{ - {"add", "example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, - {"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, + {"add", "example.com.", nbdns.PriorityUpstream, true}, + {"add", "other.example.com.", nbdns.PriorityUpstream, true}, + {"add", "sub.example.com.", nbdns.PriorityUpstream, false}, }, query: "sub.example.com.", expectedMatch: "sub.example.com.", diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 7b845235c..e81aebf98 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) muxUpdates = append(muxUpdates, handlerWrapper{ domain: customZone.Domain, handler: s.localResolver, - priority: PriorityMatchDomain, + priority: PriorityLocal, }) for _, record := range customZone.Records { @@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam groupedNS := groupNSGroupsByDomain(nameServerGroups) for _, domainGroup := range groupedNS { - basePriority := PriorityMatchDomain + basePriority := PriorityUpstream if domainGroup.domain == nbdns.RootZone { basePriority = PriorityDefault } @@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts priority := basePriority - i - // Check if we're about to overlap with the next priority tier - if basePriority == PriorityMatchDomain && priority <= PriorityDefault { + // Check if we're about to overlap with the next priority tier. + // This boundary check ensures that the priority of upstream handlers does not conflict + // with the default priority tier. By decrementing the priority for each handler, we avoid + // overlaps, but if the calculated priority falls into the default tier, we skip the remaining + // handlers to maintain the integrity of the priority system. + if basePriority == PriorityUpstream && priority <= PriorityDefault { log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", - domainGroup.domain, PriorityMatchDomain-PriorityDefault) + domainGroup.domain, PriorityUpstream-PriorityDefault) break } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index e55b27910..1cf59fb5b 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, dummyHandler.ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityLocal, }, generateDummyHandler(".", nameServers).ID(): handlerWrapper{ domain: nbdns.RootZone, @@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ domain: "netbird.io", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "local-resolver": handlerWrapper{ domain: "netbird.cloud", handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityLocal, }, }, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, @@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) { generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ domain: zoneRecords[0].Name, handler: dummyHandler, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, initSerial: 0, @@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { "id1": handlerWrapper{ domain: zoneRecords[0].Name, handler: &local.Resolver{}, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, } //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} @@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) { } chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) - chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) + chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream) testCases := []struct { name string @@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "upstream-group2": { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, } @@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, "upstream-group2": { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, "upstream-other": { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, } @@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, }, expectedHandlers: map[string]string{ @@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group3", }, - priority: PriorityMatchDomain + 1, + priority: PriorityUpstream + 1, }, // Keep existing groups with their original priorities { @@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, }, expectedHandlers: map[string]string{ @@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, // Add group3 with lowest priority { @@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group3", }, - priority: PriorityMatchDomain - 2, + priority: PriorityUpstream - 2, }, }, expectedHandlers: map[string]string{ @@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) { handler: &mockHandler{ Id: "upstream-group1", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "example.com", handler: &mockHandler{ Id: "upstream-group2", }, - priority: PriorityMatchDomain - 1, + priority: PriorityUpstream - 1, }, { domain: "other.com", handler: &mockHandler{ Id: "upstream-other", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, { domain: "new.com", handler: &mockHandler{ Id: "upstream-new", }, - priority: PriorityMatchDomain, + priority: PriorityUpstream, }, }, expectedHandlers: map[string]string{ @@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) { // Register domains from different handlers with same domain server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) - server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) + server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream) // Verify refcount is 2 zoneKey := toZone("shared.example.com") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") // Deregister one handler - server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) + server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream) // Verify refcount is 1 assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") @@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) { } server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) - server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) + server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream) assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") @@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) { assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") } + +func TestLocalResolverPriorityInServer(t *testing.T) { + server := &DefaultServer{ + ctx: context.Background(), + wgInterface: &mocWGIface{}, + handlerChain: NewHandlerChain(), + localResolver: local.NewResolver(), + service: &mockService{}, + extraDomains: make(map[domain.Domain]int), + } + + config := nbdns.Config{ + ServiceEnable: true, + CustomZones: []nbdns.CustomZone{ + { + Domain: "local.example.com", + Records: []nbdns.SimpleRecord{ + { + Name: "test.local.example.com", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.100", + }, + }, + }, + }, + NameServerGroups: []*nbdns.NameServerGroup{ + { + Domains: []string{"local.example.com"}, // Same domain as local records + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + }, + }, + } + + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + assert.NoError(t, err) + + upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups) + assert.NoError(t, err) + + // Verify that local handler has higher priority than upstream for same domain + var localPriority, upstreamPriority int + localFound, upstreamFound := false, false + + for _, update := range localMuxUpdates { + if update.domain == "local.example.com" { + localPriority = update.priority + localFound = true + } + } + + for _, update := range upstreamMuxUpdates { + if update.domain == "local.example.com" { + upstreamPriority = update.priority + upstreamFound = true + } + } + + assert.True(t, localFound, "Local handler should be found") + assert.True(t, upstreamFound, "Upstream handler should be found") + assert.Greater(t, localPriority, upstreamPriority, + "Local handler priority (%d) should be higher than upstream priority (%d)", + localPriority, upstreamPriority) + assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal") + assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream") +} + +func TestLocalResolverPriorityConstants(t *testing.T) { + // Test that priority constants are ordered correctly + assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route") + assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream") + assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default") + + // Test that local resolver uses the correct priority + server := &DefaultServer{ + localResolver: local.NewResolver(), + } + + config := nbdns.Config{ + CustomZones: []nbdns.CustomZone{ + { + Domain: "local.example.com", + Records: []nbdns.SimpleRecord{ + { + Name: "test.local.example.com", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.100", + }, + }, + }, + }, + } + + localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones) + assert.NoError(t, err) + assert.Len(t, localMuxUpdates, 1) + assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") + assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) +} diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 2fbfb3b91..c44d36599 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -2,6 +2,7 @@ package dns import ( "context" + "crypto/rand" "crypto/sha256" "encoding/hex" "errors" @@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() { // ServeDNS handles a DNS request func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + requestID := GenerateRequestID() + logger := log.WithField("request_id", requestID) var err error defer func() { u.checkUpstreamFails(err) }() - log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } select { case <-u.ctx.Done(): - log.Tracef("%s has been stopped", u) + logger.Tracef("%s has been stopped", u) return default: } @@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if err != nil { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { - log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) + logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) continue } - log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) + logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) continue } if rm == nil || !rm.Response { - log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) + logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) continue } u.successCount.Add(1) - log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) + logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) if err = w.WriteMsg(rm); err != nil { - log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) + logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) } // count the fails only if they happen sequentially u.failsCount.Store(0) return } u.failsCount.Add(1) - log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) + logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) m := new(dns.Msg) m.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(m); err != nil { - log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) + logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) } } @@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } + +func GenerateRequestID() string { + bytes := make([]byte, 4) + _, err := rand.Read(bytes) + if err != nil { + log.Errorf("failed to generate request ID: %v", err) + return "" + } + return hex.EncodeToString(bytes) +} diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 45b479632..506c429cd 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -18,14 +18,20 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/internal/peer" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) const errResolveFailed = "failed to resolve query for domain=%s: %v" const upstreamTimeout = 15 * time.Second +type resolver interface { + LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) +} + +type firewaller interface { + UpdateSet(set firewall.Set, prefixes []netip.Prefix) error +} + type DNSForwarder struct { listenAddress string ttl uint32 @@ -38,16 +44,18 @@ type DNSForwarder struct { mutex sync.RWMutex fwdEntries []*ForwarderEntry - firewall firewall.Manager + firewall firewaller + resolver resolver } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, ttl: ttl, firewall: firewall, statusRecorder: statusRecorder, + resolver: net.DefaultResolver, } } @@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { // UDP server mux := dns.NewServeMux() f.mux = mux + mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ Addr: f.listenAddress, Net: "udp", Handler: mux, } + // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux + tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ Addr: f.listenAddress, Net: "tcp", @@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { // return the first error we get (e.g. bind failure or shutdown) return <-errCh } + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() - if f.mux == nil { - log.Debug("DNS mux is nil, skipping domain update") - f.fwdEntries = entries - return - } - - oldDomains := filterDomains(f.fwdEntries) - for _, d := range oldDomains { - f.mux.HandleRemove(d.PunycodeString()) - f.tcpMux.HandleRemove(d.PunycodeString()) - } - - newDomains := filterDomains(entries) - for _, d := range newDomains { - f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP) - f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP) - } - f.fwdEntries = entries - log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) + log.Debugf("Updated DNS forwarder with %d domains", len(entries)) } func (f *DNSForwarder) Close(ctx context.Context) error { @@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns return nil } + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) + // query doesn't match any configured domain + if mostSpecificResId == "" { + resp.Rcode = dns.RcodeRefused + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } + return nil + } + ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) + ips, err := f.resolver.LookupNetIP(ctx, network, domain) if err != nil { f.handleDNSError(w, query, resp, domain, err) return nil } - f.updateInternalState(domain, ips) + f.updateInternalState(ips, mostSpecificResId, matchingEntries) f.addIPsToResponse(resp, domain, ips) return resp } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { - resp := f.handleDNSQuery(w, query) if resp == nil { return @@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { } } -func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { +func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { var prefixes []netip.Prefix - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) if mostSpecificResId != "" { for _, ip := range ips { var prefix netip.Prefix @@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar return selectedResId, matches } - -// filterDomains returns a list of normalized domains -func filterDomains(entries []*ForwarderEntry) domain.List { - newDomains := make(domain.List, 0, len(entries)) - for _, d := range entries { - if d.Domain == "" { - log.Warn("empty domain in DNS forwarder") - continue - } - newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString()))) - } - return newDomains -} diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index f0829bbbd..d8228c733 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -1,11 +1,21 @@ package dnsfwd import ( + "context" + "fmt" + "net/netip" + "strings" "testing" + "time" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -13,7 +23,7 @@ import ( func Test_getMatchingEntries(t *testing.T) { testCases := []struct { name string - storedMappings map[string]route.ResID // key: domain pattern, value: resId + storedMappings map[string]route.ResID queryDomain string expectedResId route.ResID }{ @@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) { { name: "Wildcard pattern does not match different domain", storedMappings: map[string]route.ResID{"*.example.com": "res4"}, - queryDomain: "foo.notexample.com", + queryDomain: "foo.example.org", expectedResId: "", }, { @@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) { }) } } + +type MockFirewall struct { + mock.Mock +} + +func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { + args := m.Called(set, prefixes) + return args.Error(0) +} + +type MockResolver struct { + mock.Mock +} + +func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) { + args := m.Called(ctx, network, host) + return args.Get(0).([]netip.Addr), args.Error(1) +} + +func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) { + tests := []struct { + name string + configuredDomain string + queryDomain string + shouldMatch bool + expectedResID route.ResID + description string + }{ + { + name: "exact domain match should be allowed", + configuredDomain: "example.com", + queryDomain: "example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Direct match to configured domain should work", + }, + { + name: "subdomain access should be restricted", + configuredDomain: "example.com", + queryDomain: "mail.example.com", + shouldMatch: false, + expectedResID: "", + description: "Subdomain should not be accessible unless explicitly configured", + }, + { + name: "wildcard should allow subdomains", + configuredDomain: "*.example.com", + queryDomain: "mail.example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard domains should allow subdomain access", + }, + { + name: "wildcard should allow base domain", + configuredDomain: "*.example.com", + queryDomain: "example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard should also match the base domain", + }, + { + name: "deep subdomain should be restricted", + configuredDomain: "example.com", + queryDomain: "deep.mail.example.com", + shouldMatch: false, + expectedResID: "", + description: "Deep subdomains should not be accessible", + }, + { + name: "wildcard allows deep subdomains", + configuredDomain: "*.example.com", + queryDomain: "deep.mail.example.com", + shouldMatch: true, + expectedResID: "test-res-id", + description: "Wildcard should allow deep subdomains", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + forwarder := &DNSForwarder{} + + d, err := domain.FromString(tt.configuredDomain) + require.NoError(t, err) + + entries := []*ForwarderEntry{ + { + Domain: d, + ResID: "test-res-id", + }, + } + + forwarder.UpdateDomains(entries) + + resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain) + + if tt.shouldMatch { + assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID") + assert.NotEmpty(t, matchingEntries, "Expected matching entries") + t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain) + } else { + assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match") + assert.Empty(t, matchingEntries, "Expected no matching entries") + t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain) + } + }) + } +} + +func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + tests := []struct { + name string + configuredDomain string + queryDomain string + shouldResolve bool + description string + }{ + { + name: "configured exact domain resolves", + configuredDomain: "example.com", + queryDomain: "example.com", + shouldResolve: true, + description: "Exact match should resolve", + }, + { + name: "unauthorized subdomain blocked", + configuredDomain: "example.com", + queryDomain: "mail.example.com", + shouldResolve: false, + description: "Subdomain should be blocked without wildcard", + }, + { + name: "wildcard allows subdomain", + configuredDomain: "*.example.com", + queryDomain: "mail.example.com", + shouldResolve: true, + description: "Wildcard should allow subdomain", + }, + { + name: "wildcard allows base domain", + configuredDomain: "*.example.com", + queryDomain: "example.com", + shouldResolve: true, + description: "Wildcard should allow base domain", + }, + { + name: "unrelated domain blocked", + configuredDomain: "example.com", + queryDomain: "example.org", + shouldResolve: false, + description: "Unrelated domain should be blocked", + }, + { + name: "deep subdomain blocked", + configuredDomain: "example.com", + queryDomain: "deep.mail.example.com", + shouldResolve: false, + description: "Deep subdomain should be blocked", + }, + { + name: "wildcard allows deep subdomain", + configuredDomain: "*.example.com", + queryDomain: "deep.mail.example.com", + shouldResolve: true, + description: "Wildcard should allow deep subdomain", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + if tt.shouldResolve { + mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil) + + // Mock successful DNS resolution + fakeIP := netip.MustParseAddr("1.2.3.4") + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) + } + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + d, err := domain.FromString(tt.configuredDomain) + require.NoError(t, err) + + entries := []*ForwarderEntry{ + { + Domain: d, + ResID: "test-res-id", + Set: firewall.NewDomainSet([]domain.Domain{d}), + }, + } + + forwarder.UpdateDomains(entries) + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + if tt.shouldResolve { + require.NotNil(t, resp, "Expected response for authorized domain") + require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") + assert.NotEmpty(t, resp.Answer, "Expected DNS answer records") + + time.Sleep(10 * time.Millisecond) + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) + } else { + if resp != nil { + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") + } + mockFirewall.AssertNotCalled(t, "UpdateSet") + mockResolver.AssertNotCalled(t, "LookupNetIP") + } + }) + } +} + +func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { + tests := []struct { + name string + configuredDomains []string + query string + mockIP string + shouldResolve bool + expectedSetCount int // How many sets should be updated + description string + }{ + { + name: "exact domain gets firewall update", + configuredDomains: []string{"example.com"}, + query: "example.com", + mockIP: "1.1.1.1", + shouldResolve: true, + expectedSetCount: 1, + description: "Single exact match updates one set", + }, + { + name: "wildcard domain gets firewall update", + configuredDomains: []string{"*.example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.2", + shouldResolve: true, + expectedSetCount: 1, + description: "Wildcard match updates one set", + }, + { + name: "overlapping exact and wildcard both get updates", + configuredDomains: []string{"*.example.com", "mail.example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.3", + shouldResolve: true, + expectedSetCount: 2, + description: "Both exact and wildcard sets should be updated", + }, + { + name: "unauthorized domain gets no firewall update", + configuredDomains: []string{"example.com"}, + query: "mail.example.com", + mockIP: "1.1.1.4", + shouldResolve: false, + expectedSetCount: 0, + description: "No firewall update for unauthorized domains", + }, + { + name: "multiple wildcards matching get all updated", + configuredDomains: []string{"*.example.com", "*.sub.example.com"}, + query: "test.sub.example.com", + mockIP: "1.1.1.5", + shouldResolve: true, + expectedSetCount: 2, + description: "All matching wildcard sets should be updated", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + // Set up forwarder + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Create entries and track sets + var entries []*ForwarderEntry + sets := make([]firewall.Set, 0) + + for i, configDomain := range tt.configuredDomains { + d, err := domain.FromString(configDomain) + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + sets = append(sets, set) + + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: route.ResID(fmt.Sprintf("res-%d", i)), + Set: set, + }) + } + + forwarder.UpdateDomains(entries) + + // Set up mocks + if tt.shouldResolve { + fakeIP := netip.MustParseAddr(tt.mockIP) + mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)). + Return([]netip.Addr{fakeIP}, nil).Once() + + expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)} + + // Count how many sets should actually match + updateCount := 0 + for i, entry := range entries { + domain := strings.ToLower(tt.query) + pattern := entry.Domain.PunycodeString() + + matches := false + if strings.HasPrefix(pattern, "*.") { + baseDomain := strings.TrimPrefix(pattern, "*.") + if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) { + matches = true + } + } else if domain == pattern { + matches = true + } + + if matches { + mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once() + updateCount++ + } + } + + assert.Equal(t, tt.expectedSetCount, updateCount, + "Expected %d sets to be updated, but mock expects %d", + tt.expectedSetCount, updateCount) + } + + // Execute query + dnsQuery := &dns.Msg{} + dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, dnsQuery) + + // Verify response + if tt.shouldResolve { + require.NotNil(t, resp, "Expected response for authorized domain") + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.NotEmpty(t, resp.Answer) + } else if resp != nil { + assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, + "Unauthorized domain should be refused or have no answers") + } + + // Verify all mock expectations were met + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) + }) + } +} + +// Test to verify that multiple IPs for one domain result in all prefixes being sent together +func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Configure a single domain + d, err := domain.FromString("example.com") + require.NoError(t, err) + + set := firewall.NewDomainSet([]domain.Domain{d}) + entries := []*ForwarderEntry{{ + Domain: d, + ResID: "test-res", + Set: set, + }} + + forwarder.UpdateDomains(entries) + + // Mock resolver returns multiple IPs + ips := []netip.Addr{ + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("1.1.1.2"), + netip.MustParseAddr("1.1.1.3"), + } + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com."). + Return(ips, nil).Once() + + // Expect ONE UpdateSet call with ALL prefixes + expectedPrefixes := []netip.Prefix{ + netip.PrefixFrom(ips[0], 32), + netip.PrefixFrom(ips[1], 32), + netip.PrefixFrom(ips[2], 32), + } + mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once() + + // Execute query + query := &dns.Msg{} + query.SetQuestion("example.com.", dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + // Verify response contains all IPs + require.NotNil(t, resp) + require.Equal(t, dns.RcodeSuccess, resp.Rcode) + require.Len(t, resp.Answer, 3, "Should have 3 answer records") + + // Verify mocks + mockFirewall.AssertExpectations(t) + mockResolver.AssertExpectations(t) +} + +func TestDNSForwarder_ResponseCodes(t *testing.T) { + tests := []struct { + name string + queryType uint16 + queryDomain string + configured string + expectedCode int + description string + }{ + { + name: "unauthorized domain returns REFUSED", + queryType: dns.TypeA, + queryDomain: "evil.com", + configured: "example.com", + expectedCode: dns.RcodeRefused, + description: "RFC compliant REFUSED for unauthorized queries", + }, + { + name: "unsupported query type returns NOTIMP", + queryType: dns.TypeMX, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "RFC compliant NOTIMP for unsupported types", + }, + { + name: "CNAME query returns NOTIMP", + queryType: dns.TypeCNAME, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "CNAME queries not supported", + }, + { + name: "TXT query returns NOTIMP", + queryType: dns.TypeTXT, + queryDomain: "example.com", + configured: "example.com", + expectedCode: dns.RcodeNotImplemented, + description: "TXT queries not supported", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + + d, err := domain.FromString(tt.configured) + require.NoError(t, err) + + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}} + forwarder.UpdateDomains(entries) + + query := &dns.Msg{} + query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType) + + // Capture the written response + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + + _ = forwarder.handleDNSQuery(mockWriter, query) + + // Check the response written to the writer + require.NotNil(t, writtenResp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + }) + } +} + +func TestDNSForwarder_TCPTruncation(t *testing.T) { + // Test that large UDP responses are truncated with TC bit set + mockResolver := &MockResolver{} + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder.resolver = mockResolver + + d, _ := domain.FromString("example.com") + entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}} + forwarder.UpdateDomains(entries) + + // Mock many IPs to create a large response + var manyIPs []netip.Addr + for i := 0; i < 100; i++ { + manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256))) + } + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil) + + // Query without EDNS0 + query := &dns.Msg{} + query.SetQuestion("example.com.", dns.TypeA) + + var writtenResp *dns.Msg + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writtenResp = m + return nil + }, + } + forwarder.handleDNSQueryUDP(mockWriter, query) + + require.NotNil(t, writtenResp) + assert.True(t, writtenResp.Truncated, "Large response should be truncated") + assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size") +} + +func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { + // Test complex overlapping pattern scenarios + mockFirewall := &MockFirewall{} + mockResolver := &MockResolver{} + + forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder.resolver = mockResolver + + // Set up complex overlapping patterns + patterns := []string{ + "*.example.com", // Matches all subdomains + "*.mail.example.com", // More specific wildcard + "smtp.mail.example.com", // Exact match + "example.com", // Base domain + } + + var entries []*ForwarderEntry + sets := make(map[string]firewall.Set) + + for _, pattern := range patterns { + d, _ := domain.FromString(pattern) + set := firewall.NewDomainSet([]domain.Domain{d}) + sets[pattern] = set + entries = append(entries, &ForwarderEntry{ + Domain: d, + ResID: route.ResID("res-" + pattern), + Set: set, + }) + } + + forwarder.UpdateDomains(entries) + + // Test smtp.mail.example.com - should match 3 patterns + fakeIP := netip.MustParseAddr("1.2.3.4") + mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil) + + expectedPrefix := netip.PrefixFrom(fakeIP, 32) + // All three matching patterns should get firewall updates + mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil) + + query := &dns.Msg{} + query.SetQuestion("smtp.mail.example.com.", dns.TypeA) + + mockWriter := &test.MockResponseWriter{} + resp := forwarder.handleDNSQuery(mockWriter, query) + + require.NotNil(t, resp) + assert.Equal(t, dns.RcodeSuccess, resp.Rcode) + + // Verify all three sets were updated + mockFirewall.AssertExpectations(t) + + // Verify the most specific ResID was selected + // (exact match should win over wildcards) + resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com") + assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID) + assert.Len(t, matches, 3, "Should match 3 patterns") +} + +func TestDNSForwarder_EmptyQuery(t *testing.T) { + // Test handling of malformed query with no questions + forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + + query := &dns.Msg{} + // Don't set any question + + writeCalled := false + mockWriter := &test.MockResponseWriter{ + WriteMsgFunc: func(m *dns.Msg) error { + writeCalled = true + return nil + }, + } + resp := forwarder.handleDNSQuery(mockWriter, query) + + assert.Nil(t, resp, "Should return nil for empty query") + assert.False(t, writeCalled, "Should not write response for empty query") +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 629afec9b..e290ef75f 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -575,13 +575,12 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error { // FinishPeerListModifications this event invoke the notification func (d *Status) FinishPeerListModifications() { d.mux.Lock() + defer d.mux.Unlock() if !d.peerListChangedForNotification { - d.mux.Unlock() return } d.peerListChangedForNotification = false - d.mux.Unlock() d.notifyPeerListChanged() diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index bd44ecb15..c7c3aeb0b 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -213,15 +213,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error { // ServeDNS implements the dns.Handler interface func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + requestID := nbdns.GenerateRequestID() + logger := log.WithField("request_id", requestID) + if len(r.Question) == 0 { return } - log.Tracef("received DNS request for domain=%s type=%v class=%v", + logger.Tracef("received DNS request for domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) // pass if non A/AAAA query if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { - d.continueToNextHandler(w, r, "non A/AAAA query") + d.continueToNextHandler(w, r, logger, "non A/AAAA query") return } @@ -230,19 +233,19 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { d.mu.RUnlock() if peerKey == "" { - d.writeDNSError(w, r, "no current peer key") + d.writeDNSError(w, r, logger, "no current peer key") return } upstreamIP, err := d.getUpstreamIP(peerKey) if err != nil { - d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) + d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err)) return } client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) if err != nil { - d.writeDNSError(w, r, fmt.Sprintf("create DNS client: %v", err)) + d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) return } @@ -253,9 +256,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) if err != nil { - log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { - log.Errorf("failed writing DNS response: %v", err) + logger.Errorf("failed writing DNS response: %v", err) } return } @@ -265,34 +268,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { answer = reply.Answer } - log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) + logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) reply.Id = r.Id if err := d.writeMsg(w, reply); err != nil { - log.Errorf("failed writing DNS response: %v", err) + logger.Errorf("failed writing DNS response: %v", err) } } -func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) +func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) { + logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeServerFailure) if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed to write DNS error response: %v", err) + logger.Errorf("failed to write DNS error response: %v", err) } } // continueToNextHandler signals the handler chain to try the next handler -func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { - log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) { + logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) resp := new(dns.Msg) resp.SetRcode(r, dns.RcodeNameError) // Set Zero bit to signal handler chain to continue resp.MsgHdr.Zero = true if err := w.WriteMsg(resp); err != nil { - log.Errorf("failed writing DNS continue response: %v", err) + logger.Errorf("failed writing DNS continue response: %v", err) } } diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index f7072c6b8..5441f3481 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" ) type eventHandler struct { @@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() { } func (h *eventHandler) handleUpdateClick() { - if err := openURL("https://netbird.io/download"); err != nil { + if err := openURL(version.DownloadUrl()); err != nil { log.Errorf("failed to open download URL: %v", err) } }