separate fqdn and domain validation

This commit is contained in:
Pascal Fischer
2025-08-08 15:19:13 +02:00
parent 1c1706753d
commit 30b387ba02
5 changed files with 113 additions and 37 deletions

View File

@@ -667,7 +667,7 @@ func validateDnsLabels(labels []string) (domain.List, error) {
return domains, nil return domains, nil
} }
domains, err = domain.ValidateDomains(labels) domains, err = domain.ValidateFQDNs(labels)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err) return nil, fmt.Errorf("failed to validate dns labels: %v", err)
} }

View File

@@ -8,13 +8,13 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/route"
) )
const failedToConvertRoute = "failed to convert route to response: %v" const failedToConvertRoute = "failed to convert route to response: %v"
@@ -94,7 +94,7 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
var networkType route.NetworkType var networkType route.NetworkType
var newPrefix netip.Prefix var newPrefix netip.Prefix
if req.Domains != nil { if req.Domains != nil {
d, err := domain.ValidateDomains(*req.Domains) d, err := domain.ValidateFQDNs(*req.Domains)
if err != nil { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return return
@@ -217,7 +217,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
} }
if req.Domains != nil { if req.Domains != nil {
d, err := domain.ValidateDomains(*req.Domains) d, err := domain.ValidateFQDNs(*req.Domains)
if err != nil { if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
return return

View File

@@ -539,7 +539,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { if err := domain.ValidateFQDNsList(peer.ExtraDNSLabels); err != nil {
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
} }

View File

@@ -7,11 +7,13 @@ import (
"sync" "sync"
) )
const maxDomains = 32 const maxFQDN = 32
var regexCache = map[string]*regexp.Regexp{} var regexCache = map[string]*regexp.Regexp{}
var regexCacheMu sync.Mutex var regexCacheMu sync.Mutex
var fqdnRegex = regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp {
key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel) key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel)
@@ -44,19 +46,19 @@ func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp {
return re return re
} }
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. // ValidateFQDNs checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateDomains(domains []string) (List, error) { func ValidateFQDNs(fqdns []string) (List, error) {
if len(domains) == 0 { if len(fqdns) == 0 {
return nil, fmt.Errorf("domains list is empty") return nil, fmt.Errorf("fqdns list is empty")
} }
if len(domains) > maxDomains { if len(fqdns) > maxFQDN {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) return nil, fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN)
} }
var domainList List var domainList List
for _, d := range domains { for _, d := range fqdns {
validDomain, err := ToValidDomain(d, true, false) validDomain, err := ToValidFQDN(d)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid domain %s: %w", d, err) return nil, fmt.Errorf("invalid domain %s: %w", d, err)
} }
@@ -65,21 +67,19 @@ func ValidateDomains(domains []string) (List, error) {
return domainList, nil return domainList, nil
} }
// ValidateDomainsList checks if each domain in the list is valid // ValidateFQDNsList checks if each domain in the list is valid
func ValidateDomainsList(domains []string) error { func ValidateFQDNsList(fqdns []string) error {
if len(domains) == 0 { if len(fqdns) == 0 {
return nil return nil
} }
if len(domains) > maxDomains { if len(fqdns) > maxFQDN {
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) return fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN)
} }
domainRegex := buildDomainRegex(false, true) for _, d := range fqdns {
for _, d := range domains {
d := strings.ToLower(d) d := strings.ToLower(d)
if !domainRegex.MatchString(d) { if !fqdnRegex.MatchString(d) {
return fmt.Errorf("invalid domain format: %s", d) return fmt.Errorf("invalid fqdns format: %s", d)
} }
} }
return nil return nil
@@ -112,3 +112,18 @@ func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Doma
return punycode, nil return punycode, nil
} }
// ToValidFQDN converts a domain to a valid fqdn format.
func ToValidFQDN(domain string) (Domain, error) {
// handles length and idna conversion
punycode, err := FromString(domain)
if err != nil {
return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err)
}
if !fqdnRegex.MatchString(string(punycode)) {
return "", fmt.Errorf("invalid domain format: %s", domain)
}
return punycode, nil
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestValidateDomains(t *testing.T) { func TestValidateFQDNs(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
domains []string domains []string
@@ -63,10 +63,10 @@ func TestValidateDomains(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "Valid wildcard domain", name: "Invalid wildcard domain",
domains: []string{"*.example.com"}, domains: []string{"*.example.com"},
expected: List{"*.example.com"}, expected: nil,
wantErr: false, wantErr: true,
}, },
{ {
name: "Wildcard with dot domain", name: "Wildcard with dot domain",
@@ -90,16 +90,16 @@ func TestValidateDomains(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := ValidateDomains(tt.domains) got, err := ValidateFQDNs(tt.domains)
assert.Equal(t, tt.wantErr, err != nil) assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, got, tt.expected) assert.Equal(t, got, tt.expected)
}) })
} }
} }
func TestValidateDomainsList(t *testing.T) { func TestValidateFQDNsList(t *testing.T) {
validDomains := make([]string, maxDomains) validDomains := make([]string, maxFQDN)
for i := range maxDomains { for i := range maxFQDN {
validDomains[i] = fmt.Sprintf("example%d.com", i) validDomains[i] = fmt.Sprintf("example%d.com", i)
} }
@@ -124,7 +124,7 @@ func TestValidateDomainsList(t *testing.T) {
wantErr: false, wantErr: false,
}, },
{ {
// Unlike ValidateDomains (which converts to punycode), // Unlike ValidateFQDNs (which converts to punycode),
// ValidateDomainsStrSlice will fail on non-ASCII domain chars. // ValidateDomainsStrSlice will fail on non-ASCII domain chars.
name: "Unicode domain fails (no punycode conversion)", name: "Unicode domain fails (no punycode conversion)",
domains: []string{"münchen.de"}, domains: []string{"münchen.de"},
@@ -161,12 +161,12 @@ func TestValidateDomainsList(t *testing.T) {
wantErr: true, wantErr: true,
}, },
{ {
name: "Exactly maxDomains items (valid)", name: "Exactly maxFQDN items (valid)",
domains: validDomains, domains: validDomains,
wantErr: false, wantErr: false,
}, },
{ {
name: "Exceeds maxDomains items", name: "Exceeds maxFQDN items",
domains: append(validDomains, "extra.com"), domains: append(validDomains, "extra.com"),
wantErr: true, wantErr: true,
}, },
@@ -174,7 +174,7 @@ func TestValidateDomainsList(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := ValidateDomainsList(tt.domains) err := ValidateFQDNsList(tt.domains)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@@ -183,3 +183,64 @@ func TestValidateDomainsList(t *testing.T) {
}) })
} }
} }
func TestIsValidDomain(t *testing.T) {
tests := []struct {
name string
domain string
valid bool
}{
{
name: "Empty domain",
domain: "",
valid: false,
},
{
name: "Single valid ASCII domain",
domain: "sub.ex-ample.com",
valid: true,
},
{
name: "Underscores in labels",
domain: "_jabber._tcp.gmail.com",
valid: false,
},
{
name: "Unicode domain fails (no punycode conversion)",
domain: "münchen.de",
valid: true,
},
{
name: "Invalid domain format - leading dash",
domain: "-example.com",
valid: false,
},
{
name: "Invalid domain format - trailing dash",
domain: "example-.com",
valid: false,
},
{
name: "Valid wildcard domain",
domain: "*.example.com",
valid: true,
},
{
name: "Wildcard with leading dot - invalid",
domain: ".*.example.com",
valid: false,
},
{
name: "Invalid wildcard with multiple asterisks",
domain: "a.*.example.com",
valid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
valid := IsValidDomain(tt.domain, true, true)
assert.Equal(t, tt.valid, valid)
})
}
}