// Mirror of https://github.com/netbirdio/netbird.git
// (synced 2026-04-18 08:16:39 +00:00).
//
// Latest change: exclude the Flow domain from caching to prevent TLS failures
// due to stale records; fix test.
//
// 418 lines, 12 KiB, Go.

package mgmt
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/miekg/dns"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
|
"github.com/netbirdio/netbird/shared/management/domain"
|
|
)
|
|
|
|
func TestResolver_NewResolver(t *testing.T) {
|
|
resolver := NewResolver()
|
|
|
|
assert.NotNil(t, resolver)
|
|
assert.NotNil(t, resolver.records)
|
|
assert.False(t, resolver.MatchSubdomains())
|
|
}
|
|
|
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
urlStr string
|
|
expectedDom string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "HTTPS URL with port",
|
|
urlStr: "https://api.netbird.io:443",
|
|
expectedDom: "api.netbird.io",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "HTTP URL without port",
|
|
urlStr: "http://signal.example.com",
|
|
expectedDom: "signal.example.com",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "URL with path",
|
|
urlStr: "https://relay.netbird.io/status",
|
|
expectedDom: "relay.netbird.io",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Invalid URL",
|
|
urlStr: "not-a-valid-url",
|
|
expectedDom: "not-a-valid-url",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Empty URL",
|
|
urlStr: "",
|
|
expectedDom: "",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "STUN URL",
|
|
urlStr: "stun:stun.example.com:3478",
|
|
expectedDom: "stun.example.com",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "TURN URL",
|
|
urlStr: "turn:turn.example.com:3478",
|
|
expectedDom: "turn.example.com",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "REL URL",
|
|
urlStr: "rel://relay.example.com:443",
|
|
expectedDom: "relay.example.com",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "RELS URL",
|
|
urlStr: "rels://relay.example.com:443",
|
|
expectedDom: "relay.example.com",
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var parsedURL *url.URL
|
|
var err error
|
|
|
|
if tt.urlStr != "" {
|
|
parsedURL, err = url.Parse(tt.urlStr)
|
|
if err != nil && !tt.expectError {
|
|
t.Fatalf("Failed to parse URL: %v", err)
|
|
}
|
|
}
|
|
|
|
domain, err := extractDomainFromURL(parsedURL)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, tt.expectedDom, domain.SafeString())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolver_PopulateFromConfig(t *testing.T) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
resolver := NewResolver()
|
|
|
|
// Test with IP address - should return error since IP addresses are rejected
|
|
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
|
|
|
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
|
assert.Error(t, err)
|
|
assert.ErrorIs(t, err, dnsconfig.ErrIPNotAllowed)
|
|
|
|
// No domains should be cached when using IP addresses
|
|
domains := resolver.GetCachedDomains()
|
|
assert.Equal(t, 0, len(domains), "No domains should be cached when using IP addresses")
|
|
}
|
|
|
|
func TestResolver_ServeDNS(t *testing.T) {
|
|
resolver := NewResolver()
|
|
ctx := context.Background()
|
|
|
|
// Add a test domain to the cache - use example.org which is reserved for testing
|
|
testDomain, err := domain.FromString("example.org")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create domain: %v", err)
|
|
}
|
|
err = resolver.AddDomain(ctx, testDomain)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
|
|
// Test A record query for cached domain
|
|
t.Run("Cached domain A record", func(t *testing.T) {
|
|
var capturedMsg *dns.Msg
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
capturedMsg = m
|
|
return nil
|
|
},
|
|
}
|
|
|
|
req := new(dns.Msg)
|
|
req.SetQuestion("example.org.", dns.TypeA)
|
|
|
|
resolver.ServeDNS(mockWriter, req)
|
|
|
|
assert.NotNil(t, capturedMsg)
|
|
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
|
|
assert.True(t, len(capturedMsg.Answer) > 0, "Should have at least one answer")
|
|
})
|
|
|
|
// Test uncached domain signals to continue to next handler
|
|
t.Run("Uncached domain signals continue to next handler", func(t *testing.T) {
|
|
var capturedMsg *dns.Msg
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
capturedMsg = m
|
|
return nil
|
|
},
|
|
}
|
|
|
|
req := new(dns.Msg)
|
|
req.SetQuestion("unknown.example.com.", dns.TypeA)
|
|
|
|
resolver.ServeDNS(mockWriter, req)
|
|
|
|
assert.NotNil(t, capturedMsg)
|
|
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
|
|
// Zero flag set to true signals the handler chain to continue to next handler
|
|
assert.True(t, capturedMsg.MsgHdr.Zero, "Zero flag should be set to signal continuation to next handler")
|
|
assert.Empty(t, capturedMsg.Answer, "Should have no answers for uncached domain")
|
|
})
|
|
|
|
// Test that subdomains of cached domains are NOT resolved
|
|
t.Run("Subdomains of cached domains are not resolved", func(t *testing.T) {
|
|
var capturedMsg *dns.Msg
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
capturedMsg = m
|
|
return nil
|
|
},
|
|
}
|
|
|
|
// Query for a subdomain of our cached domain
|
|
req := new(dns.Msg)
|
|
req.SetQuestion("sub.example.org.", dns.TypeA)
|
|
|
|
resolver.ServeDNS(mockWriter, req)
|
|
|
|
assert.NotNil(t, capturedMsg)
|
|
assert.Equal(t, dns.RcodeNameError, capturedMsg.Rcode)
|
|
assert.True(t, capturedMsg.MsgHdr.Zero, "Should signal continuation to next handler for subdomains")
|
|
assert.Empty(t, capturedMsg.Answer, "Should have no answers for subdomains")
|
|
})
|
|
|
|
// Test case-insensitive matching
|
|
t.Run("Case-insensitive domain matching", func(t *testing.T) {
|
|
var capturedMsg *dns.Msg
|
|
mockWriter := &test.MockResponseWriter{
|
|
WriteMsgFunc: func(m *dns.Msg) error {
|
|
capturedMsg = m
|
|
return nil
|
|
},
|
|
}
|
|
|
|
// Query with different casing
|
|
req := new(dns.Msg)
|
|
req.SetQuestion("EXAMPLE.ORG.", dns.TypeA)
|
|
|
|
resolver.ServeDNS(mockWriter, req)
|
|
|
|
assert.NotNil(t, capturedMsg)
|
|
assert.Equal(t, dns.RcodeSuccess, capturedMsg.Rcode)
|
|
assert.True(t, len(capturedMsg.Answer) > 0, "Should resolve regardless of case")
|
|
})
|
|
}
|
|
|
|
func TestResolver_GetCachedDomains(t *testing.T) {
|
|
resolver := NewResolver()
|
|
ctx := context.Background()
|
|
|
|
testDomain, err := domain.FromString("example.org")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create domain: %v", err)
|
|
}
|
|
err = resolver.AddDomain(ctx, testDomain)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
|
|
cachedDomains := resolver.GetCachedDomains()
|
|
|
|
assert.Equal(t, 1, len(cachedDomains), "Should return exactly one domain for single added domain")
|
|
assert.Equal(t, testDomain.SafeString(), cachedDomains[0].SafeString(), "Cached domain should match original")
|
|
assert.False(t, strings.HasSuffix(cachedDomains[0].PunycodeString(), "."), "Domain should not have trailing dot")
|
|
}
|
|
|
|
func TestResolver_ManagementDomainProtection(t *testing.T) {
|
|
resolver := NewResolver()
|
|
ctx := context.Background()
|
|
|
|
mgmtURL, _ := url.Parse("https://example.org")
|
|
err := resolver.PopulateFromConfig(ctx, mgmtURL)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
|
|
initialDomains := resolver.GetCachedDomains()
|
|
if len(initialDomains) == 0 {
|
|
t.Skip("Management domain failed to resolve, skipping test")
|
|
}
|
|
assert.Equal(t, 1, len(initialDomains), "Should have management domain cached")
|
|
assert.Equal(t, "example.org", initialDomains[0].SafeString())
|
|
|
|
serverDomains := dnsconfig.ServerDomains{
|
|
Signal: "google.com",
|
|
Relay: []domain.Domain{"cloudflare.com"},
|
|
}
|
|
|
|
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
|
|
if err != nil {
|
|
t.Logf("Server domains update failed: %v", err)
|
|
}
|
|
|
|
finalDomains := resolver.GetCachedDomains()
|
|
|
|
managementStillCached := false
|
|
for _, d := range finalDomains {
|
|
if d.SafeString() == "example.org" {
|
|
managementStillCached = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(t, managementStillCached, "Management domain should never be removed")
|
|
}
|
|
|
|
// extractDomainFromURL extracts a domain from a URL - test helper function
|
|
func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
|
if u == nil {
|
|
return "", fmt.Errorf("URL is nil")
|
|
}
|
|
return dnsconfig.ExtractValidDomain(u.String())
|
|
}
|
|
|
|
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
|
resolver := NewResolver()
|
|
ctx := context.Background()
|
|
|
|
// Set up initial domains using resolvable domains
|
|
initialDomains := dnsconfig.ServerDomains{
|
|
Signal: "example.org",
|
|
Stuns: []domain.Domain{"google.com"},
|
|
Turns: []domain.Domain{"cloudflare.com"},
|
|
}
|
|
|
|
// Add initial domains
|
|
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
|
|
// Verify domains were added
|
|
cachedDomains := resolver.GetCachedDomains()
|
|
assert.Len(t, cachedDomains, 3)
|
|
|
|
// Update with empty ServerDomains (simulating partial network map update)
|
|
emptyDomains := dnsconfig.ServerDomains{}
|
|
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
|
|
assert.NoError(t, err)
|
|
|
|
// Verify no domains were removed
|
|
assert.Len(t, removedDomains, 0, "No domains should be removed when update is empty")
|
|
|
|
// Verify all original domains are still cached
|
|
finalDomains := resolver.GetCachedDomains()
|
|
assert.Len(t, finalDomains, 3, "All original domains should still be cached")
|
|
}
|
|
|
|
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
|
resolver := NewResolver()
|
|
ctx := context.Background()
|
|
|
|
// Set up initial complete domains using resolvable domains
|
|
initialDomains := dnsconfig.ServerDomains{
|
|
Signal: "example.org",
|
|
Stuns: []domain.Domain{"google.com"},
|
|
Turns: []domain.Domain{"cloudflare.com"},
|
|
}
|
|
|
|
// Add initial domains
|
|
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
assert.Len(t, resolver.GetCachedDomains(), 3)
|
|
|
|
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
|
|
partialDomains := dnsconfig.ServerDomains{
|
|
Signal: "github.com",
|
|
}
|
|
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
|
|
// Should remove only the old signal domain
|
|
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
|
|
assert.Equal(t, "example.org", removedDomains[0].SafeString())
|
|
|
|
finalDomains := resolver.GetCachedDomains()
|
|
assert.Len(t, finalDomains, 3, "Should have new signal plus preserved stun/turn domains")
|
|
|
|
domainStrings := make([]string, len(finalDomains))
|
|
for i, d := range finalDomains {
|
|
domainStrings[i] = d.SafeString()
|
|
}
|
|
assert.Contains(t, domainStrings, "github.com")
|
|
assert.Contains(t, domainStrings, "google.com")
|
|
assert.Contains(t, domainStrings, "cloudflare.com")
|
|
assert.NotContains(t, domainStrings, "example.org")
|
|
}
|
|
|
|
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|
resolver := NewResolver()
|
|
ctx := context.Background()
|
|
|
|
// Set up initial complete domains using resolvable domains
|
|
initialDomains := dnsconfig.ServerDomains{
|
|
Signal: "example.org",
|
|
Stuns: []domain.Domain{"google.com"},
|
|
Turns: []domain.Domain{"cloudflare.com"},
|
|
}
|
|
|
|
// Add initial domains
|
|
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
assert.Len(t, resolver.GetCachedDomains(), 3)
|
|
|
|
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
|
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
|
partialDomains := dnsconfig.ServerDomains{
|
|
Flow: "github.com",
|
|
}
|
|
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
|
|
if err != nil {
|
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
|
}
|
|
|
|
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
|
|
|
finalDomains := resolver.GetCachedDomains()
|
|
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
|
|
|
domainStrings := make([]string, len(finalDomains))
|
|
for i, d := range finalDomains {
|
|
domainStrings[i] = d.SafeString()
|
|
}
|
|
assert.Contains(t, domainStrings, "example.org")
|
|
assert.Contains(t, domainStrings, "google.com")
|
|
assert.Contains(t, domainStrings, "cloudflare.com")
|
|
assert.NotContains(t, domainStrings, "github.com")
|
|
}
|