mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
move domain validation to management
This commit is contained in:
@@ -0,0 +1,60 @@
|
|||||||
|
// Package domainvalidation provides a mechanism for verifying ownership of
|
||||||
|
// a domain.
|
||||||
|
// It is intended to be used before custom domains can be assigned to reverse
|
||||||
|
// proxy services.
|
||||||
|
// Acceptable domains should be set to the known domains that reverse proxy
|
||||||
|
// servers are hosted at.
|
||||||
|
//
|
||||||
|
// After a custom domain is validated, it should be pinned to a single account
|
||||||
|
// to prevent domain abuse across accounts.
|
||||||
|
package domainvalidation
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resolver interface {
|
||||||
|
LookupCNAME(context.Context, string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Validator struct {
|
||||||
|
resolver resolver
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewValidator initializes a validator with a specific DNS resolver.
|
||||||
|
// If a Validator is used without specifying a resolver, then it will
|
||||||
|
// use the net.DefaultResolver.
|
||||||
|
func NewValidator(resolver resolver) *Validator {
|
||||||
|
return &Validator{
|
||||||
|
resolver: resolver,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid looks up the CNAME record for the passed domain and compares it
|
||||||
|
// against the acceptable domains.
|
||||||
|
// If the returned CNAME matches any accepted domain, it will return true,
|
||||||
|
// otherwise, including in the event of a DNS error, it will return false.
|
||||||
|
// The comparison is very simple, so wildcards will not match if included
|
||||||
|
// in the acceptable domain list.
|
||||||
|
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
|
||||||
|
if v.resolver == nil {
|
||||||
|
v.resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
cname, err := v.resolver.LookupCNAME(ctx, domain)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing).
|
||||||
|
nakedCNAME := strings.TrimSuffix(cname, ".")
|
||||||
|
for _, domain := range accept {
|
||||||
|
// Currently, the match is a very simple string comparison.
|
||||||
|
if nakedCNAME == strings.TrimSuffix(domain, ".") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package domainvalidation_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/services/domainvalidation"
|
||||||
|
)
|
||||||
|
|
||||||
|
type resolver struct {
|
||||||
|
CNAME string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
|
||||||
|
return r.CNAME, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValid(t *testing.T) {
|
||||||
|
tests := map[string]struct {
|
||||||
|
resolver interface {
|
||||||
|
LookupCNAME(context.Context, string) (string, error)
|
||||||
|
}
|
||||||
|
domain string
|
||||||
|
accept []string
|
||||||
|
expect bool
|
||||||
|
}{
|
||||||
|
"match": {
|
||||||
|
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com"},
|
||||||
|
expect: true,
|
||||||
|
},
|
||||||
|
"no match": {
|
||||||
|
resolver: resolver{"invalid"},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com"},
|
||||||
|
expect: false,
|
||||||
|
},
|
||||||
|
"accept trailing dot": {
|
||||||
|
resolver: resolver{"bar.example.com."},
|
||||||
|
domain: "foo.example.com",
|
||||||
|
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
|
||||||
|
expect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, test := range tests {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
validator := domainvalidation.NewValidator(test.resolver)
|
||||||
|
actual := validator.IsValid(t.Context(), test.domain, test.accept)
|
||||||
|
if test.expect != actual {
|
||||||
|
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user