diff --git a/management/internals/modules/services/domainvalidation/validator.go b/management/internals/modules/services/domainvalidation/validator.go new file mode 100644 index 000000000..682c4417f --- /dev/null +++ b/management/internals/modules/services/domainvalidation/validator.go @@ -0,0 +1,60 @@ +// Package domainvalidation provides a mechanism for verifying ownership of +// a domain. +// It is intended to be used before custom domains can be assigned to reverse +// proxy services. +// Acceptable domains should be set to the known domains that reverse proxy +// servers are hosted at. +// +// After a custom domain is validated, it should be pinned to a single account +// to prevent domain abuse across accounts. +package domainvalidation + +import ( + "context" + "net" + "strings" +) + +type resolver interface { + LookupCNAME(context.Context, string) (string, error) +} + +type Validator struct { + resolver resolver +} + +// NewValidator initializes a validator with a specific DNS resolver. +// If a Validator is used without specifying a resolver, then it will +// use the net.DefaultResolver. +func NewValidator(resolver resolver) *Validator { + return &Validator{ + resolver: resolver, + } +} + +// IsValid looks up the CNAME record for the passed domain and compares it +// against the acceptable domains. +// If the returned CNAME matches any accepted domain, it will return true, +// otherwise, including in the event of a DNS error, it will return false. +// The comparison is very simple, so wildcards will not match if included +// in the acceptable domain list. +func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool { + if v.resolver == nil { + v.resolver = net.DefaultResolver + } + + cname, err := v.resolver.LookupCNAME(ctx, domain) + if err != nil { + return false + } + + // Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing). + nakedCNAME := strings.TrimSuffix(cname, ".") + for _, domain := range accept { + // Currently, the match is a very simple string comparison. + if nakedCNAME == strings.TrimSuffix(domain, ".") { + return true + } + } + return false +} diff --git a/management/internals/modules/services/domainvalidation/validator_test.go b/management/internals/modules/services/domainvalidation/validator_test.go new file mode 100644 index 000000000..c9e7d0419 --- /dev/null +++ b/management/internals/modules/services/domainvalidation/validator_test.go @@ -0,0 +1,56 @@ +package domainvalidation_test + +import ( + "context" + "testing" + + "github.com/netbirdio/netbird/management/internals/modules/services/domainvalidation" +) + +type resolver struct { + CNAME string +} + +func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) { + return r.CNAME, nil +} + +func TestIsValid(t *testing.T) { + tests := map[string]struct { + resolver interface { + LookupCNAME(context.Context, string) (string, error) + } + domain string + accept []string + expect bool + }{ + "match": { + resolver: resolver{"bar.example.com."}, // Including trailing "." in response. + domain: "foo.example.com", + accept: []string{"bar.example.com"}, + expect: true, + }, + "no match": { + resolver: resolver{"invalid"}, + domain: "foo.example.com", + accept: []string{"bar.example.com"}, + expect: false, + }, + "accept trailing dot": { + resolver: resolver{"bar.example.com."}, + domain: "foo.example.com", + accept: []string{"bar.example.com."}, // Including trailing "." in accept. + expect: true, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + validator := domainvalidation.NewValidator(test.resolver) + actual := validator.IsValid(t.Context(), test.domain, test.accept) + if test.expect != actual { + t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual) + } + }) + } +}