move domain validation to management

This commit is contained in:
Alisdair MacLeod
2026-01-27 09:58:14 +00:00
parent 703ef29199
commit 245bbb4acf
2 changed files with 116 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
// Package domainvalidation provides a mechanism for verifying ownership of
// a domain.
// It is intended to be used before custom domains can be assigned to reverse
// proxy services.
// Acceptable domains should be set to the known domains that reverse proxy
// servers are hosted at.
//
// After a custom domain is validated, it should be pinned to a single account
// to prevent domain abuse across accounts.
package domainvalidation
import (
"context"
"net"
"strings"
)
type resolver interface {
LookupCNAME(context.Context, string) (string, error)
}
type Validator struct {
resolver resolver
}
// NewValidator initializes a validator with a specific DNS resolver.
// If a Validator is used without specifying a resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
return &Validator{
resolver: resolver,
}
}
// IsValid looks up the CNAME record for the passed domain and compares it
// against the acceptable domains.
// If the returned CNAME matches any accepted domain, it will return true,
// otherwise, including in the event of a DNS error, it will return false.
// The comparison is very simple, so wildcards will not match if included
// in the acceptable domain list.
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
if v.resolver == nil {
v.resolver = net.DefaultResolver
}
cname, err := v.resolver.LookupCNAME(ctx, domain)
if err != nil {
return false
}
// Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing).
nakedCNAME := strings.TrimSuffix(cname, ".")
for _, domain := range accept {
// Currently, the match is a very simple string comparison.
if nakedCNAME == strings.TrimSuffix(domain, ".") {
return true
}
}
return false
}

View File

@@ -0,0 +1,56 @@
package domainvalidation_test
import (
"context"
"testing"
"github.com/netbirdio/netbird/management/internals/modules/services/domainvalidation"
)
type resolver struct {
CNAME string
}
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
return r.CNAME, nil
}
func TestIsValid(t *testing.T) {
tests := map[string]struct {
resolver interface {
LookupCNAME(context.Context, string) (string, error)
}
domain string
accept []string
expect bool
}{
"match": {
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
domain: "foo.example.com",
accept: []string{"bar.example.com"},
expect: true,
},
"no match": {
resolver: resolver{"invalid"},
domain: "foo.example.com",
accept: []string{"bar.example.com"},
expect: false,
},
"accept trailing dot": {
resolver: resolver{"bar.example.com."},
domain: "foo.example.com",
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
expect: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
validator := domainvalidation.NewValidator(test.resolver)
actual := validator.IsValid(t.Context(), test.domain, test.accept)
if test.expect != actual {
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
}
})
}
}