diff --git a/client/cmd/up.go b/client/cmd/up.go index 8732a687d..8af3b4a50 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -667,7 +667,7 @@ func validateDnsLabels(labels []string) (domain.List, error) { return domains, nil } - domains, err = domain.ValidateDomains(labels) + domains, err = domain.ValidateFQDNs(labels) if err != nil { return nil, fmt.Errorf("failed to validate dns labels: %v", err) } diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 7950db1e8..112d43c08 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -8,13 +8,13 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/account" nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - "github.com/netbirdio/netbird/route" ) const failedToConvertRoute = "failed to convert route to response: %v" @@ -94,7 +94,7 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { var networkType route.NetworkType var newPrefix netip.Prefix if req.Domains != nil { - d, err := domain.ValidateDomains(*req.Domains) + d, err := domain.ValidateFQDNs(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return @@ -217,7 +217,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { } if req.Domains != nil { - d, err := domain.ValidateDomains(*req.Domains) + d, err := domain.ValidateFQDNs(*req.Domains) if err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w) return diff --git a/management/server/peer.go b/management/server/peer.go index d72eac91a..eede5c8c3 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -539,7 +539,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil { + if err := domain.ValidateFQDNsList(peer.ExtraDNSLabels); err != nil { return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err) } diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 74a7901c1..7100eeb4c 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -7,11 +7,13 @@ import ( "sync" ) -const maxDomains = 32 +const maxFQDN = 32 var regexCache = map[string]*regexp.Regexp{} var regexCacheMu sync.Mutex +var fqdnRegex = regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel) @@ -44,19 +46,19 @@ func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { return re } -// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. -func ValidateDomains(domains []string) (List, error) { - if len(domains) == 0 { - return nil, fmt.Errorf("domains list is empty") +// ValidateFQDNs checks if each domain in the list is valid and returns a punycode-encoded DomainList. +func ValidateFQDNs(fqdns []string) (List, error) { + if len(fqdns) == 0 { + return nil, fmt.Errorf("fqdns list is empty") } - if len(domains) > maxDomains { - return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) + if len(fqdns) > maxFQDN { + return nil, fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN) } var domainList List - for _, d := range domains { - validDomain, err := ToValidDomain(d, true, false) + for _, d := range fqdns { + validDomain, err := ToValidFQDN(d) if err != nil { return nil, fmt.Errorf("invalid domain %s: %w", d, err) } @@ -65,21 +67,19 @@ func ValidateDomains(domains []string) (List, error) { return domainList, nil } -// ValidateDomainsList checks if each domain in the list is valid -func ValidateDomainsList(domains []string) error { - if len(domains) == 0 { +// ValidateFQDNsList checks if each domain in the list is valid +func ValidateFQDNsList(fqdns []string) error { + if len(fqdns) == 0 { return nil } - if len(domains) > maxDomains { - return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) + if len(fqdns) > maxFQDN { + return fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN) } - domainRegex := buildDomainRegex(false, true) - - for _, d := range domains { + for _, d := range fqdns { d := strings.ToLower(d) - if !domainRegex.MatchString(d) { - return fmt.Errorf("invalid domain format: %s", d) + if !fqdnRegex.MatchString(d) { + return fmt.Errorf("invalid fqdns format: %s", d) } } return nil @@ -112,3 +112,18 @@ func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Doma return punycode, nil } + +// ToValidFQDN converts a domain to a valid fqdn format. +func ToValidFQDN(domain string) (Domain, error) { + // handles length and idna conversion + punycode, err := FromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err) + } + + if !fqdnRegex.MatchString(string(punycode)) { + return "", fmt.Errorf("invalid domain format: %s", domain) + } + + return punycode, nil +} diff --git a/shared/management/domain/validate_test.go b/shared/management/domain/validate_test.go index c52a8ee10..f71130f19 100644 --- a/shared/management/domain/validate_test.go +++ b/shared/management/domain/validate_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestValidateDomains(t *testing.T) { +func TestValidateFQDNs(t *testing.T) { tests := []struct { name string domains []string @@ -63,10 +63,10 @@ func TestValidateDomains(t *testing.T) { wantErr: true, }, { - name: "Valid wildcard domain", + name: "Invalid wildcard domain", domains: []string{"*.example.com"}, - expected: List{"*.example.com"}, - wantErr: false, + expected: nil, + wantErr: true, }, { name: "Wildcard with dot domain", @@ -90,16 +90,16 @@ func TestValidateDomains(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := ValidateDomains(tt.domains) + got, err := ValidateFQDNs(tt.domains) assert.Equal(t, tt.wantErr, err != nil) assert.Equal(t, got, tt.expected) }) } } -func TestValidateDomainsList(t *testing.T) { - validDomains := make([]string, maxDomains) - for i := range maxDomains { +func TestValidateFQDNsList(t *testing.T) { + validDomains := make([]string, maxFQDN) + for i := range maxFQDN { validDomains[i] = fmt.Sprintf("example%d.com", i) } @@ -124,7 +124,7 @@ func TestValidateDomainsList(t *testing.T) { wantErr: false, }, { - // Unlike ValidateDomains (which converts to punycode), + // Unlike ValidateFQDNs (which converts to punycode), // ValidateDomainsStrSlice will fail on non-ASCII domain chars. name: "Unicode domain fails (no punycode conversion)", domains: []string{"münchen.de"}, @@ -161,12 +161,12 @@ func TestValidateDomainsList(t *testing.T) { wantErr: true, }, { - name: "Exactly maxDomains items (valid)", + name: "Exactly maxFQDN items (valid)", domains: validDomains, wantErr: false, }, { - name: "Exceeds maxDomains items", + name: "Exceeds maxFQDN items", domains: append(validDomains, "extra.com"), wantErr: true, }, @@ -174,7 +174,7 @@ func TestValidateDomainsList(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := ValidateDomainsList(tt.domains) + err := ValidateFQDNsList(tt.domains) if tt.wantErr { assert.Error(t, err) } else { @@ -183,3 +183,64 @@ func TestValidateDomainsList(t *testing.T) { }) } } + +func TestIsValidDomain(t *testing.T) { + tests := []struct { + name string + domain string + valid bool + }{ + { + name: "Empty domain", + domain: "", + valid: false, + }, + { + name: "Single valid ASCII domain", + domain: "sub.ex-ample.com", + valid: true, + }, + { + name: "Underscores in labels", + domain: "_jabber._tcp.gmail.com", + valid: false, + }, + { + name: "Unicode domain fails (no punycode conversion)", + domain: "münchen.de", + valid: true, + }, + { + name: "Invalid domain format - leading dash", + domain: "-example.com", + valid: false, + }, + { + name: "Invalid domain format - trailing dash", + domain: "example-.com", + valid: false, + }, + { + name: "Valid wildcard domain", + domain: "*.example.com", + valid: true, + }, + { + name: "Wildcard with leading dot - invalid", + domain: ".*.example.com", + valid: false, + }, + { + name: "Invalid wildcard with multiple asterisks", + domain: "a.*.example.com", + valid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + valid := IsValidDomain(tt.domain, true, true) + assert.Equal(t, tt.valid, valid) + }) + } +}