connect api to store and manager for domains

This commit is contained in:
Alisdair MacLeod
2026-01-27 15:43:54 +00:00
parent b7eeefc102
commit 50f42caf94
11 changed files with 153 additions and 38 deletions

View File

@@ -0,0 +1,127 @@
package domain
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager Manager
}
func RegisterEndpoints(router *mux.Router, manager Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
}
func domainTypeToApi(t domainType) api.ReverseProxyDomainType {
switch t {
case TypeCustom:
return api.ReverseProxyDomainTypeCustom
case TypeFree:
return api.ReverseProxyDomainTypeFree
}
// By default return as a "free" domain as that is more restrictive.
// TODO: is this correct?
return api.ReverseProxyDomainTypeFree
}
func domainToApi(d *Domain) api.ReverseProxyDomain {
return api.ReverseProxyDomain{
Domain: d.Domain,
Id: d.ID,
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
}
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId)
var ret []api.ReverseProxyDomain
for _, d := range domains {
ret = append(ret, domainToApi(d))
}
util.WriteJSONObject(r.Context(), w, ret)
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiReverseProxyDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
domain, err := h.manager.CreateDomain(r.Context(), userAuth.AccountId, req.Domain)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
return
}
if err := h.manager.DeleteDomain(r.Context(), userAuth.AccountId, domainID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
w.WriteHeader(http.StatusNoContent)
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
return
}
go h.manager.ValidateDomain(userAuth.AccountId, domainID)
w.WriteHeader(http.StatusAccepted)
}

View File

@@ -0,0 +1,124 @@
package domain
import (
"context"
"fmt"
"net"
"github.com/netbirdio/netbird/management/server/types"
)
type domainType string
const (
TypeFree domainType = "free"
TypeCustom domainType = "custom"
)
type Domain struct {
ID string `gorm:"unique;primaryKey;autoIncrement"`
Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts.
AccountID string `gorm:"index"`
Type domainType `gorm:"-"`
Validated bool
}
type store interface {
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
ListCustomDomains(ctx context.Context, accountID string) ([]*Domain, error)
CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*Domain, error)
UpdateCustomDomain(ctx context.Context, accountID string, d *Domain) (*Domain, error)
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type Manager struct {
store store
validator Validator
}
func NewManager(store store) Manager {
return Manager{
store: store,
validator: Validator{
resolver: net.DefaultResolver,
},
}
}
func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, error) {
account, err := m.store.GetAccount(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
free, err := m.store.ListFreeDomains(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("list free domains: %w", err)
}
domains, err := m.store.ListCustomDomains(ctx, accountID)
if err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return nil, fmt.Errorf("list custom domains: %w", err)
}
// Prepend each free domain with the account nonce and then add it to the domain
// array to be returned.
// This account nonce is added to free domains to prevent users being able to
// query free domain usage across accounts and simplifies tracking free domain
// usage across accounts.
for _, name := range free {
domains = append(domains, &Domain{
Domain: account.ReverseProxyFreeDomainNonce + "." + name,
AccountID: accountID,
Type: TypeFree,
Validated: true,
})
}
return domains, nil
}
func (m Manager) CreateDomain(ctx context.Context, accountID, domainName string) (*Domain, error) {
// Attempt an initial validation; however, a failure is still acceptable for creation
// because the user may not yet have configured their DNS records, or the DNS update
// has not yet reached the servers that are queried by the validation resolver.
var validated bool
// TODO: retrieve in use reverse proxy addresses from somewhere!
var reverseProxyAddresses []string
if m.validator.IsValid(ctx, domainName, reverseProxyAddresses) {
validated = true
}
d, err := m.store.CreateCustomDomain(ctx, accountID, domainName, validated)
if err != nil {
return d, fmt.Errorf("create domain in store: %w", err)
}
return d, nil
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) error {
if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil {
// TODO: check for "no records" type error. Because that is a success condition.
return fmt.Errorf("delete domain from store: %w", err)
}
return nil
}
func (m Manager) ValidateDomain(accountID, domainID string) {
d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID)
if err != nil {
// TODO: something? Log?
return
}
// TODO: retrieve in use reverse proxy addresses from somewhere!
var reverseProxyAddresses []string
if m.validator.IsValid(context.Background(), d.Domain, reverseProxyAddresses) {
d.Validated = true
if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil {
// TODO: something? Log?
return
}
}
}

View File

@@ -0,0 +1,51 @@
package domain
import (
"context"
"net"
"strings"
)
type resolver interface {
LookupCNAME(context.Context, string) (string, error)
}
type Validator struct {
resolver resolver
}
// NewValidator initializes a validator with a specific DNS resolver.
// If a Validator is used without specifying a resolver, then it will
// use the net.DefaultResolver.
func NewValidator(resolver resolver) *Validator {
return &Validator{
resolver: resolver,
}
}
// IsValid looks up the CNAME record for the passed domain and compares it
// against the acceptable domains.
// If the returned CNAME matches any accepted domain, it will return true,
// otherwise, including in the event of a DNS error, it will return false.
// The comparison is very simple, so wildcards will not match if included
// in the acceptable domain list.
func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool {
if v.resolver == nil {
v.resolver = net.DefaultResolver
}
cname, err := v.resolver.LookupCNAME(ctx, domain)
if err != nil {
return false
}
// Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing).
nakedCNAME := strings.TrimSuffix(cname, ".")
for _, domain := range accept {
// Currently, the match is a very simple string comparison.
if nakedCNAME == strings.TrimSuffix(domain, ".") {
return true
}
}
return false
}

View File

@@ -0,0 +1,56 @@
package domain_test
import (
"context"
"testing"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
)
type resolver struct {
CNAME string
}
func (r resolver) LookupCNAME(_ context.Context, _ string) (string, error) {
return r.CNAME, nil
}
func TestIsValid(t *testing.T) {
tests := map[string]struct {
resolver interface {
LookupCNAME(context.Context, string) (string, error)
}
domain string
accept []string
expect bool
}{
"match": {
resolver: resolver{"bar.example.com."}, // Including trailing "." in response.
domain: "foo.example.com",
accept: []string{"bar.example.com"},
expect: true,
},
"no match": {
resolver: resolver{"invalid"},
domain: "foo.example.com",
accept: []string{"bar.example.com"},
expect: false,
},
"accept trailing dot": {
resolver: resolver{"bar.example.com."},
domain: "foo.example.com",
accept: []string{"bar.example.com."}, // Including trailing "." in accept.
expect: true,
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
validator := domain.NewValidator(test.resolver)
actual := validator.IsValid(t.Context(), test.domain, test.accept)
if test.expect != actual {
t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual)
}
})
}
}