diff --git a/management/server/account.go b/management/server/account.go index 333c9cca7..47e09b95e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -224,7 +224,7 @@ func BuildManager( // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { - if !domain.IsValidDomain(singleAccountModeDomain) { + if !domain.IsValidDomain(singleAccountModeDomain, false, false) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain @@ -401,7 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain) { + if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain, false, true) { return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } @@ -1518,7 +1518,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return userAuth.AccountId, nil } - if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain) { + if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain, false, false) { return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index a9b7e3cf7..380625ecb 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -2,7 +2,6 @@ package server import ( "context" - "strings" "unicode/utf8" "github.com/rs/xid" @@ -263,10 +262,7 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo } for _, domain := range domains { - if strings.HasPrefix(domain, "*") { - return status.Errorf(status.InvalidArgument, "wildcard prefix is not allowed: %s", domain) - } - if !nbDomain.IsValidDomain(domain) { + if !nbDomain.IsValidDomain(domain, false, true) { return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) } } diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 1c5248f88..2f9347394 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -165,7 +165,7 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil } - if domain, err := nbDomain.ToValidDomain(address); err == nil { + if domain, err := nbDomain.ToValidDomain(address, true, false); err == nil { return Domain, string(domain), netip.Prefix{}, nil } diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index 6a52e636f..1d508a331 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -4,11 +4,45 @@ import ( "fmt" "regexp" "strings" + "sync" ) const maxDomains = 32 -var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)(?:\.(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`) +var regexCache = map[string]*regexp.Regexp{} +var regexCacheMu sync.Mutex + +func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp { + key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel) + + regexCacheMu.Lock() + defer regexCacheMu.Unlock() + + if re, ok := regexCache[key]; ok { + return re + } + + var pattern strings.Builder + pattern.WriteString("^") + + if allowWildcard { + pattern.WriteString(`(?:\*\.)?`) + } + + label := `(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?` + + if allowSingleToplevel { + pattern.WriteString(label + `(?:\.` + label + `)*`) + } else { + pattern.WriteString(label + `(?:\.` + label + `)+`) + } + + pattern.WriteString("$") + + re := regexp.MustCompile(pattern.String()) + regexCache[key] = re + return re +} // ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList. func ValidateDomains(domains []string) (List, error) { @@ -22,7 +56,7 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - validDomain, err := ToValidDomain(d) + validDomain, err := ToValidDomain(d, false, false) if err != nil { return nil, fmt.Errorf("invalid domain %s: %w", d, err) } @@ -40,6 +74,8 @@ func ValidateDomainsList(domains []string) error { return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } + domainRegex := buildDomainRegex(false, true) + for _, d := range domains { d := strings.ToLower(d) if !domainRegex.MatchString(d) { @@ -50,24 +86,26 @@ func ValidateDomainsList(domains []string) error { } // IsValidDomain checks if the given domain is valid. -func IsValidDomain(domain string) bool { +func IsValidDomain(domain string, allowWildcard, allowSingleToplevel bool) bool { // handles length and idna conversion punycode, err := FromString(domain) if err != nil { return false } + domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel) return domainRegex.MatchString(string(punycode)) } // ToValidDomain converts a domain to a valid domain format. -func ToValidDomain(domain string) (Domain, error) { +func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Domain, error) { // handles length and idna conversion punycode, err := FromString(domain) if err != nil { return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err) } + domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel) if !domainRegex.MatchString(string(punycode)) { return "", fmt.Errorf("invalid domain format: %s", domain) }