mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
6 Commits
merged-fix
...
chore/unif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
78c886eb53 | ||
|
|
30b387ba02 | ||
|
|
1c1706753d | ||
|
|
b5da6d3f8e | ||
|
|
0af0447f1b | ||
|
|
6124405f94 |
@@ -667,7 +667,7 @@ func validateDnsLabels(labels []string) (domain.List, error) {
|
|||||||
return domains, nil
|
return domains, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
domains, err = domain.ValidateDomains(labels)
|
domains, err = domain.ValidateFQDNs(labels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"regexp"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -45,6 +44,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/management/server/util"
|
"github.com/netbirdio/netbird/management/server/util"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -224,7 +224,7 @@ func BuildManager(
|
|||||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||||
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
|
am.singleAccountMode = singleAccountModeDomain != "" && accountsCounter <= 1
|
||||||
if am.singleAccountMode {
|
if am.singleAccountMode {
|
||||||
if !isDomainValid(singleAccountModeDomain) {
|
if !domain.IsValidDomain(singleAccountModeDomain, false, false) {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
|
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain)
|
||||||
}
|
}
|
||||||
am.singleAccountModeDomain = singleAccountModeDomain
|
am.singleAccountModeDomain = singleAccountModeDomain
|
||||||
@@ -401,7 +401,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
|
|||||||
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||||
}
|
}
|
||||||
|
|
||||||
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
|
if newSettings.DNSDomain != "" && !domain.IsValidDomain(newSettings.DNSDomain, false, true) {
|
||||||
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1518,7 +1518,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
|
|||||||
return userAuth.AccountId, nil
|
return userAuth.AccountId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if userAuth.DomainCategory != types.PrivateCategory || !isDomainValid(userAuth.Domain) {
|
if userAuth.DomainCategory != types.PrivateCategory || !domain.IsValidDomain(userAuth.Domain, false, false) {
|
||||||
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
|
return am.GetAccountIDByUserID(ctx, userAuth.UserId, userAuth.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1701,12 +1701,6 @@ func (am *DefaultAccountManager) HasConnectedChannel(peerID string) bool {
|
|||||||
return am.peersUpdateManager.HasChannel(peerID)
|
return am.peersUpdateManager.HasChannel(peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
var invalidDomainRegexp = regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
|
|
||||||
|
|
||||||
func isDomainValid(domain string) bool {
|
|
||||||
return invalidDomainRegexp.MatchString(domain)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDNSDomain returns the configured dnsDomain
|
// GetDNSDomain returns the configured dnsDomain
|
||||||
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
|
func (am *DefaultAccountManager) GetDNSDomain(settings *types.Settings) string {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
|
|||||||
@@ -8,13 +8,13 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const failedToConvertRoute = "failed to convert route to response: %v"
|
const failedToConvertRoute = "failed to convert route to response: %v"
|
||||||
@@ -217,7 +217,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if req.Domains != nil {
|
if req.Domains != nil {
|
||||||
d, err := domain.ValidateDomains(*req.Domains)
|
d, err := domain.ValidateFQDNs(*req.Domains)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid domains: %v", err), w)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,11 +2,8 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"regexp"
|
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@@ -15,13 +12,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
|
|
||||||
|
|
||||||
var invalidDomainName = errors.New("invalid domain name")
|
|
||||||
|
|
||||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||||
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
|
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Nameservers, operations.Read)
|
||||||
@@ -268,8 +262,8 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
if err := validateDomain(domain); err != nil {
|
if !nbDomain.IsValidDomain(domain, false, true) {
|
||||||
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s %q", domain, err)
|
return status.Errorf(status.InvalidArgument, "nameserver group got an invalid domain: %s", domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -313,18 +307,3 @@ func validateGroups(list []string, groups map[string]*types.Group) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var domainMatcher = regexp.MustCompile(domainPattern)
|
|
||||||
|
|
||||||
func validateDomain(domain string) error {
|
|
||||||
if !domainMatcher.MatchString(domain) {
|
|
||||||
return errors.New("domain should consists of only letters, numbers, and hyphens with no leading, trailing hyphens, or spaces")
|
|
||||||
}
|
|
||||||
|
|
||||||
_, valid := dns.IsDomainName(domain)
|
|
||||||
if !valid {
|
|
||||||
return invalidDomainName
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -910,12 +910,12 @@ func TestValidateDomain(t *testing.T) {
|
|||||||
errFunc: require.NoError,
|
errFunc: require.NoError,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Valid domain name with trailing dot",
|
name: "Invalid domain name with trailing dot",
|
||||||
domain: "example.",
|
domain: "example.",
|
||||||
errFunc: require.NoError,
|
errFunc: require.Error,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid wildcard domain name",
|
name: "Valid wildcard domain name",
|
||||||
domain: "*.example",
|
domain: "*.example",
|
||||||
errFunc: require.Error,
|
errFunc: require.Error,
|
||||||
},
|
},
|
||||||
@@ -932,7 +932,7 @@ func TestValidateDomain(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Invalid domain name with double hyphen",
|
name: "Invalid domain name with double hyphen",
|
||||||
domain: "test--example.com",
|
domain: "test--example.com",
|
||||||
errFunc: require.Error,
|
errFunc: require.NoError, // Note: Double hyphen is not valid but due to punicode hard to filter out
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Invalid domain name with a label exceeding 63 characters",
|
name: "Invalid domain name with a label exceeding 63 characters",
|
||||||
@@ -968,7 +968,7 @@ func TestValidateDomain(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range testCases {
|
for _, testCase := range testCases {
|
||||||
t.Run(testCase.name, func(t *testing.T) {
|
t.Run(testCase.name, func(t *testing.T) {
|
||||||
testCase.errFunc(t, validateDomain(testCase.domain))
|
testCase.errFunc(t, validateDomainInput(false, []string{testCase.domain}, false))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,15 +4,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"regexp"
|
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
nbDomain "github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||||
)
|
)
|
||||||
@@ -166,9 +165,8 @@ func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix,
|
|||||||
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
|
return Host, "", netip.PrefixFrom(ip, ip.BitLen()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
|
if domain, err := nbDomain.ToValidDomain(address, true, false); err == nil {
|
||||||
if domainRegex.MatchString(address) {
|
return Domain, string(domain), netip.Prefix{}, nil
|
||||||
return Domain, address, netip.Prefix{}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain")
|
return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain")
|
||||||
|
|||||||
@@ -539,7 +539,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
|
if err := domain.ValidateFQDNsList(peer.ExtraDNSLabels); err != nil {
|
||||||
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,53 +4,147 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxDomains = 32
|
const maxFQDN = 32
|
||||||
|
|
||||||
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
var regexCache = map[string]*regexp.Regexp{}
|
||||||
|
var regexCacheMu sync.Mutex
|
||||||
|
|
||||||
|
var fqdnRegex = regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||||
|
|
||||||
|
func buildDomainRegex(allowWildcard, allowSingleToplevel bool) *regexp.Regexp {
|
||||||
|
key := fmt.Sprintf("%t:%t", allowWildcard, allowSingleToplevel)
|
||||||
|
|
||||||
|
regexCacheMu.Lock()
|
||||||
|
defer regexCacheMu.Unlock()
|
||||||
|
|
||||||
|
if re, ok := regexCache[key]; ok {
|
||||||
|
return re
|
||||||
|
}
|
||||||
|
|
||||||
|
var pattern strings.Builder
|
||||||
|
pattern.WriteString("^")
|
||||||
|
|
||||||
|
if allowWildcard {
|
||||||
|
pattern.WriteString(`(?:\*\.)?`)
|
||||||
|
}
|
||||||
|
|
||||||
|
label := `(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?`
|
||||||
|
|
||||||
|
if allowSingleToplevel {
|
||||||
|
pattern.WriteString(label + `(?:\.` + label + `)*`)
|
||||||
|
} else {
|
||||||
|
pattern.WriteString(label + `(?:\.` + label + `)+`)
|
||||||
|
}
|
||||||
|
|
||||||
|
pattern.WriteString("$")
|
||||||
|
|
||||||
|
re := regexp.MustCompile(pattern.String())
|
||||||
|
regexCache[key] = re
|
||||||
|
return re
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateFQDNs checks if each domain in the list is valid and returns a punycode-encoded DomainList.
|
||||||
|
func ValidateFQDNs(fqdns []string) (List, error) {
|
||||||
|
if len(fqdns) == 0 {
|
||||||
|
return nil, fmt.Errorf("fqdns list is empty")
|
||||||
|
}
|
||||||
|
if len(fqdns) > maxFQDN {
|
||||||
|
return nil, fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN)
|
||||||
|
}
|
||||||
|
|
||||||
|
var domainList List
|
||||||
|
|
||||||
|
for _, d := range fqdns {
|
||||||
|
validDomain, err := ToValidFQDN(d)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid domain %s: %w", d, err)
|
||||||
|
}
|
||||||
|
domainList = append(domainList, validDomain)
|
||||||
|
}
|
||||||
|
return domainList, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
|
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
|
||||||
func ValidateDomains(domains []string) (List, error) {
|
func ValidateDomains(domains []string) (List, error) {
|
||||||
if len(domains) == 0 {
|
if len(domains) == 0 {
|
||||||
return nil, fmt.Errorf("domains list is empty")
|
return nil, fmt.Errorf("domains list is empty")
|
||||||
}
|
}
|
||||||
if len(domains) > maxDomains {
|
if len(domains) > maxFQDN {
|
||||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxFQDN)
|
||||||
}
|
}
|
||||||
|
|
||||||
var domainList List
|
var domainList List
|
||||||
|
|
||||||
for _, d := range domains {
|
for _, d := range domains {
|
||||||
// handles length and idna conversion
|
validDomain, err := ToValidDomain(d, true, true)
|
||||||
punycode, err := FromString(d)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err)
|
return nil, fmt.Errorf("invalid domain %s: %w", d, err)
|
||||||
}
|
}
|
||||||
|
domainList = append(domainList, validDomain)
|
||||||
if !domainRegex.MatchString(string(punycode)) {
|
|
||||||
return domainList, fmt.Errorf("invalid domain format: %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
domainList = append(domainList, punycode)
|
|
||||||
}
|
}
|
||||||
return domainList, nil
|
return domainList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateDomainsList checks if each domain in the list is valid
|
// ValidateFQDNsList checks if each domain in the list is valid
|
||||||
func ValidateDomainsList(domains []string) error {
|
func ValidateFQDNsList(fqdns []string) error {
|
||||||
if len(domains) == 0 {
|
if len(fqdns) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if len(domains) > maxDomains {
|
if len(fqdns) > maxFQDN {
|
||||||
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
return fmt.Errorf("fqdns list exceeds maximum allowed fqdns: %d", maxFQDN)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, d := range domains {
|
for _, d := range fqdns {
|
||||||
d := strings.ToLower(d)
|
d := strings.ToLower(d)
|
||||||
if !domainRegex.MatchString(d) {
|
if !fqdnRegex.MatchString(d) {
|
||||||
return fmt.Errorf("invalid domain format: %s", d)
|
return fmt.Errorf("invalid fqdns format: %s", d)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValidDomain checks if the given domain is valid.
|
||||||
|
func IsValidDomain(domain string, allowWildcard, allowSingleToplevel bool) bool {
|
||||||
|
// handles length and idna conversion
|
||||||
|
punycode, err := FromString(domain)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel)
|
||||||
|
return domainRegex.MatchString(string(punycode))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToValidDomain converts a domain to a valid domain format.
|
||||||
|
func ToValidDomain(domain string, allowWildcard, allowSingleToplevel bool) (Domain, error) {
|
||||||
|
// handles length and idna conversion
|
||||||
|
punycode, err := FromString(domain)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
domainRegex := buildDomainRegex(allowWildcard, allowSingleToplevel)
|
||||||
|
if !domainRegex.MatchString(string(punycode)) {
|
||||||
|
return "", fmt.Errorf("invalid domain format: %s", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return punycode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToValidFQDN converts a domain to a valid fqdn format.
|
||||||
|
func ToValidFQDN(domain string) (Domain, error) {
|
||||||
|
// handles length and idna conversion
|
||||||
|
punycode, err := FromString(domain)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("convert domain to punycode: %s: %w", domain, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fqdnRegex.MatchString(string(punycode)) {
|
||||||
|
return "", fmt.Errorf("invalid domain format: %s", domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
return punycode, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestValidateDomains(t *testing.T) {
|
func TestValidateFQDNs(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
domains []string
|
domains []string
|
||||||
@@ -59,14 +59,14 @@ func TestValidateDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Multiple domains valid and invalid",
|
name: "Multiple domains valid and invalid",
|
||||||
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
|
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
|
||||||
expected: List{"google.com"},
|
expected: nil,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Valid wildcard domain",
|
name: "Invalid wildcard domain",
|
||||||
domains: []string{"*.example.com"},
|
domains: []string{"*.example.com"},
|
||||||
expected: List{"*.example.com"},
|
expected: nil,
|
||||||
wantErr: false,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Wildcard with dot domain",
|
name: "Wildcard with dot domain",
|
||||||
@@ -90,16 +90,16 @@ func TestValidateDomains(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := ValidateDomains(tt.domains)
|
got, err := ValidateFQDNs(tt.domains)
|
||||||
assert.Equal(t, tt.wantErr, err != nil)
|
assert.Equal(t, tt.wantErr, err != nil)
|
||||||
assert.Equal(t, got, tt.expected)
|
assert.Equal(t, got, tt.expected)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateDomainsList(t *testing.T) {
|
func TestValidateFQDNsList(t *testing.T) {
|
||||||
validDomains := make([]string, maxDomains)
|
validDomains := make([]string, maxFQDN)
|
||||||
for i := range maxDomains {
|
for i := range maxFQDN {
|
||||||
validDomains[i] = fmt.Sprintf("example%d.com", i)
|
validDomains[i] = fmt.Sprintf("example%d.com", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ func TestValidateDomainsList(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Unlike ValidateDomains (which converts to punycode),
|
// Unlike ValidateFQDNs (which converts to punycode),
|
||||||
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
|
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
|
||||||
name: "Unicode domain fails (no punycode conversion)",
|
name: "Unicode domain fails (no punycode conversion)",
|
||||||
domains: []string{"münchen.de"},
|
domains: []string{"münchen.de"},
|
||||||
@@ -146,9 +146,9 @@ func TestValidateDomainsList(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Valid wildcard domain",
|
name: "Invalid wildcard domain",
|
||||||
domains: []string{"*.example.com"},
|
domains: []string{"*.example.com"},
|
||||||
wantErr: false,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Wildcard with leading dot - invalid",
|
name: "Wildcard with leading dot - invalid",
|
||||||
@@ -161,12 +161,12 @@ func TestValidateDomainsList(t *testing.T) {
|
|||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Exactly maxDomains items (valid)",
|
name: "Exactly maxFQDN items (valid)",
|
||||||
domains: validDomains,
|
domains: validDomains,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Exceeds maxDomains items",
|
name: "Exceeds maxFQDN items",
|
||||||
domains: append(validDomains, "extra.com"),
|
domains: append(validDomains, "extra.com"),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@@ -174,7 +174,7 @@ func TestValidateDomainsList(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
err := ValidateDomainsList(tt.domains)
|
err := ValidateFQDNsList(tt.domains)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
} else {
|
} else {
|
||||||
@@ -183,3 +183,64 @@ func TestValidateDomainsList(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsValidDomain(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domain string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty domain",
|
||||||
|
domain: "",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Single valid ASCII domain",
|
||||||
|
domain: "sub.ex-ample.com",
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Underscores in labels",
|
||||||
|
domain: "_jabber._tcp.gmail.com",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Unicode domain fails (no punycode conversion)",
|
||||||
|
domain: "münchen.de",
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid domain format - leading dash",
|
||||||
|
domain: "-example.com",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid domain format - trailing dash",
|
||||||
|
domain: "example-.com",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid wildcard domain",
|
||||||
|
domain: "*.example.com",
|
||||||
|
valid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard with leading dot - invalid",
|
||||||
|
domain: ".*.example.com",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid wildcard with multiple asterisks",
|
||||||
|
domain: "a.*.example.com",
|
||||||
|
valid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
valid := IsValidDomain(tt.domain, true, true)
|
||||||
|
assert.Equal(t, tt.valid, valid)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user