diff --git a/management/server/account.go b/management/server/account.go index 0f60bc91c..333c9cca7 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -9,7 +9,6 @@ import ( "net/netip" "os" "reflect" - "regexp" "slices" "strconv" "strings" @@ -45,6 +44,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) @@ -224,7 +224,7 @@ func BuildManager( // enable single account mode only if configured by user and number of existing accounts is not grater than 1 am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1 if am.singleAccountMode { - if !isDomainValid(singleAccountModeDomain) { + if !domain.IsValidDomain(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain @@ -401,7 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain) { return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } @@ -1518,7 +1518,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return userAuth.AccountId, nil } - if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) { + if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain) { return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain) } @@ -1701,12 +1701,6 @@ func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool { return am.peersUpdateManager.HasChannel(peerID) } -var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`) - -func isDomainValid(domain string) bool { - return invalidDomainRegexp.MatchString(domain) -} - // GetDNSDomain returns the configured dnsDomain func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string { if settings == nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 1ee8805fc..0a55d80f7 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -2,11 +2,8 @@ package server import ( "context" - "errors" - "regexp" "unicode/utf8" - "github.com/miekg/dns" "github.com/rs/xid" nbdns "github.com/netbirdio/netbird/dns" @@ -15,13 +12,10 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) -const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` - -var invalidDomainName = errors.New("invalid domain name") - // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read) @@ -268,8 +262,8 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo } for _, domain := range domains { - if err := validateDomain(domain); err != nil { - return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err) + if nbDomain.IsValidDomain(domain) { + return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain) } } return nil @@ -313,18 +307,3 @@ func validateGroups(list []string, groups map[string]*types.Group) error { return nil } - -var domainMatcher = regexp.MustCompile(domainPattern) - -func validateDomain(domain string) error { - if !domainMatcher.MatchString(domain) { - return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces") - } - - _, valid := dns.IsDomainName(domain) - if !valid { - return invalidDomainName - } - - return nil -} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7874be858..1c5248f88 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -4,15 +4,14 @@ import ( "errors" "fmt" "net/netip" - "regexp" "github.com/rs/xid" - nbDomain "github.com/netbirdio/netbird/shared/management/domain" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/route" + nbDomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" ) @@ -166,9 +165,8 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil } - domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) - if domainRegex.MatchString(address) { - return Domain, address, netip.Prefix{}, nil + if domain, err := nbDomain.ToValidDomain(address); err == nil { + return Domain, string(domain), netip.Prefix{}, nil } return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain") diff --git a/shared/management/domain/validate.go b/shared/management/domain/validate.go index bf2af7116..6744c704b 100644 --- a/shared/management/domain/validate.go +++ b/shared/management/domain/validate.go @@ -22,17 +22,11 @@ func ValidateDomains(domains []string) (List, error) { var domainList List for _, d := range domains { - // handles length and idna conversion - punycode, err := FromString(d) + validDomain, err := ToValidDomain(d) if err != nil { - return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err) + return nil, fmt.Errorf("invalid domain %s: %w", d, err) } - - if !domainRegex.MatchString(string(punycode)) { - return domainList, fmt.Errorf("invalid domain format: %s", d) - } - - domainList = append(domainList, punycode) + domainList = append(domainList, validDomain) } return domainList, nil } @@ -54,3 +48,29 @@ func ValidateDomainsList(domains []string) error { } return nil } + +// IsValidDomain checks if the given domain is valid. +func IsValidDomain(domain string) bool { + // handles length and idna conversion + punycode, err := FromString(domain) + if err != nil { + return false + } + + return !domainRegex.MatchString(string(punycode)) +} + +// ToValidDomain converts a domain to a valid domain format. +func ToValidDomain(domain string) (Domain, error) { + // handles length and idna conversion + punycode, err := FromString(domain) + if err != nil { + return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err) + } + + if !domainRegex.MatchString(string(punycode)) { + return "", fmt.Errorf("invalid domain format: %s", domain) + } + + return punycode, nil +}