diff --git a/management/internals/modules/reverseproxy/domains/api.go b/management/internals/modules/reverseproxy/domain/api.go similarity index 93% rename from management/internals/modules/reverseproxy/domains/api.go rename to management/internals/modules/reverseproxy/domain/api.go index 0824fa143..d7368d3cb 100644 --- a/management/internals/modules/reverseproxy/domains/api.go +++ b/management/internals/modules/reverseproxy/domain/api.go @@ -1,4 +1,4 @@ -package domains +package domain import ( "encoding/json" @@ -13,11 +13,13 @@ import ( ) type handler struct { - manager manager + manager Manager } -func RegisterEndpoints(router *mux.Router) { - h := &handler{} +func RegisterEndpoints(router *mux.Router, manager Manager) { + h := &handler{ + manager: manager, + } router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS") router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS") @@ -27,9 +29,9 @@ func RegisterEndpoints(router *mux.Router) { func domainTypeToApi(t domainType) api.ReverseProxyDomainType { switch t { - case domainTypeCustom: + case TypeCustom: return api.ReverseProxyDomainTypeCustom - case domainTypeFree: + case TypeFree: return api.ReverseProxyDomainTypeFree } // By default return as a "free" domain as that is more restrictive. @@ -37,7 +39,7 @@ func domainTypeToApi(t domainType) api.ReverseProxyDomainType { return api.ReverseProxyDomainTypeFree } -func domainToApi(d domain) api.ReverseProxyDomain { +func domainToApi(d *Domain) api.ReverseProxyDomain { return api.ReverseProxyDomain{ Domain: d.Domain, Id: d.ID, diff --git a/management/internals/modules/reverseproxy/domains/manager.go b/management/internals/modules/reverseproxy/domain/manager.go similarity index 63% rename from management/internals/modules/reverseproxy/domains/manager.go rename to management/internals/modules/reverseproxy/domain/manager.go index 1d6fc13c9..0a4174d4d 100644 --- a/management/internals/modules/reverseproxy/domains/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager.go @@ -1,43 +1,57 @@ -package domains +package domain import ( "context" "fmt" + "net" + + "github.com/netbirdio/netbird/management/server/types" ) type domainType string const ( - domainTypeFree domainType = "free" - domainTypeCustom domainType = "custom" + TypeFree domainType = "free" + TypeCustom domainType = "custom" ) -type domain struct { - ID string - Domain string - Type domainType +type Domain struct { + ID string `gorm:"unique;primaryKey;autoIncrement"` + Domain string `gorm:"unique"` // Domain records must be unique, this avoids domain reuse across accounts. + AccountID string `gorm:"index"` + Type domainType `gorm:"-"` Validated bool } type store interface { - GetAccountDomainNonce(ctx context.Context, accountID string) (nonce string, err error) - GetCustomDomain(ctx context.Context, accountID string, domainID string) (domain, error) + GetAccount(ctx context.Context, accountID string) (*types.Account, error) + + GetCustomDomain(ctx context.Context, accountID string, domainID string) (*Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) - ListCustomDomains(ctx context.Context, accountID string) ([]domain, error) - CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (domain, error) - UpdateCustomDomain(ctx context.Context, accountID string, d domain) (domain, error) + ListCustomDomains(ctx context.Context, accountID string) ([]*Domain, error) + CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*Domain, error) + UpdateCustomDomain(ctx context.Context, accountID string, d *Domain) (*Domain, error) DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error } -type manager struct { +type Manager struct { store store validator Validator } -func (m manager) GetDomains(ctx context.Context, accountID string) ([]domain, error) { - nonce, err := m.store.GetAccountDomainNonce(ctx, accountID) +func NewManager(store store) Manager { + return Manager{ + store: store, + validator: Validator{ + resolver: net.DefaultResolver, + }, + } +} + +func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, error) { + account, err := m.store.GetAccount(ctx, accountID) if err != nil { - return nil, fmt.Errorf("get account domain nonce: %w", err) + return nil, fmt.Errorf("get account: %w", err) } free, err := m.store.ListFreeDomains(ctx, accountID) if err != nil { @@ -55,16 +69,17 @@ func (m manager) GetDomains(ctx context.Context, accountID string) ([]domain, er // query free domain usage across accounts and simplifies tracking free domain // usage across accounts. for _, name := range free { - domains = append(domains, domain{ - Domain: nonce + "." + name, - Type: domainTypeFree, + domains = append(domains, &Domain{ + Domain: account.ReverseProxyFreeDomainNonce + "." + name, + AccountID: accountID, + Type: TypeFree, Validated: true, }) } return domains, nil } -func (m manager) CreateDomain(ctx context.Context, accountID, domainName string) (domain, error) { +func (m Manager) CreateDomain(ctx context.Context, accountID, domainName string) (*Domain, error) { // Attempt an initial validation; however, a failure is still acceptable for creation // because the user may not yet have configured their DNS records, or the DNS update // has not yet reached the servers that are queried by the validation resolver. @@ -83,7 +98,7 @@ func (m manager) CreateDomain(ctx context.Context, accountID, domainName string) return d, nil } -func (m manager) DeleteDomain(ctx context.Context, accountID, domainID string) error { +func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) error { if err := m.store.DeleteCustomDomain(ctx, accountID, domainID); err != nil { // TODO: check for "no records" type error. Because that is a success condition. return fmt.Errorf("delete domain from store: %w", err) @@ -91,7 +106,7 @@ func (m manager) DeleteDomain(ctx context.Context, accountID, domainID string) e return nil } -func (m manager) ValidateDomain(accountID, domainID string) { +func (m Manager) ValidateDomain(accountID, domainID string) { d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID) if err != nil { // TODO: something? Log? diff --git a/management/internals/modules/reverseproxy/domains/validator.go b/management/internals/modules/reverseproxy/domain/validator.go similarity index 98% rename from management/internals/modules/reverseproxy/domains/validator.go rename to management/internals/modules/reverseproxy/domain/validator.go index bc58e9b07..ca5279644 100644 --- a/management/internals/modules/reverseproxy/domains/validator.go +++ b/management/internals/modules/reverseproxy/domain/validator.go @@ -1,4 +1,4 @@ -package domains +package domain import ( "context" diff --git a/management/internals/modules/reverseproxy/domains/validator_test.go b/management/internals/modules/reverseproxy/domain/validator_test.go similarity index 93% rename from management/internals/modules/reverseproxy/domains/validator_test.go rename to management/internals/modules/reverseproxy/domain/validator_test.go index eaea4185d..1f9583728 100644 --- a/management/internals/modules/reverseproxy/domains/validator_test.go +++ b/management/internals/modules/reverseproxy/domain/validator_test.go @@ -1,10 +1,10 @@ -package domains_test +package domain_test import ( "context" "testing" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domains" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" ) type resolver struct { @@ -46,7 +46,7 @@ func TestIsValid(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - validator := domains.NewValidator(test.resolver) + validator := domain.NewValidator(test.resolver) actual := validator.IsValid(t.Context(), test.domain, test.accept) if test.expect != actual { t.Errorf("Incorrect return value:\nexpect: %v\nactual: %v", test.expect, actual) diff --git a/management/internals/modules/reverseproxy/manager/api.go b/management/internals/modules/reverseproxy/manager/api.go index 736940976..3577f128b 100644 --- a/management/internals/modules/reverseproxy/manager/api.go +++ b/management/internals/modules/reverseproxy/manager/api.go @@ -7,6 +7,7 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -17,7 +18,7 @@ type handler struct { manager reverseproxy.Manager } -func RegisterEndpoints(manager reverseproxy.Manager, router *mux.Router) { +func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, router *mux.Router) { h := &handler{ manager: manager, } @@ -27,6 +28,9 @@ func RegisterEndpoints(manager reverseproxy.Manager, router *mux.Router) { router.HandleFunc("/reverse-proxies/{reverseProxyId}", h.getReverseProxy).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/{reverseProxyId}", h.updateReverseProxy).Methods("PUT", "OPTIONS") router.HandleFunc("/reverse-proxies/{reverseProxyId}", h.deleteReverseProxy).Methods("DELETE", "OPTIONS") + + // Hang domain endpoints off the main router here. + domain.RegisterEndpoints(router.PathPrefix("/reverse-proxies").Subrouter(), domainManager) } func (h *handler) getAllReverseProxies(w http.ResponseWriter, r *http.Request) { diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 7795ad804..6f13b2d4b 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index ffeb0abdf..9ce3832ce 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" @@ -182,3 +183,9 @@ func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer()) }) } + +func (s *BaseServer) ReverseProxyDomainManager() domain.Manager { + return Create(s, func() domain.Manager { + return domain.NewManager(s.Store()) + }) +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 67a9a721f..0416a2ad4 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -13,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" idpmanager "github.com/netbirdio/netbird/management/server/idp" @@ -62,7 +63,7 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager domain.Manager) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -158,7 +159,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) - reverseproxymanager.RegisterEndpoints(reverseProxyManager, router) + reverseproxymanager.RegisterEndpoints(reverseProxyManager, reverseProxyDomainManager, router) // Mount embedded IdP handler at /oauth2 path if configured if embeddedIdpEnabled { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 731fb6cb1..01f531183 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -28,6 +28,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -127,7 +128,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, - &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, + &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &domain.Domain{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -4688,3 +4689,78 @@ func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength Lo return proxyList, nil } + +func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) { + tx := s.db + + var customDomain *domain.Domain + result := tx.Take(&customDomain, accountAndIDQueryCondition, accountID, domainID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "reverse proxy custom domain %s not found", domainID) + } + + log.WithContext(ctx).Errorf("failed to get reverse proxy custom domain from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get reverse proxy custom domain from store") + } + + return customDomain, nil +} + +func (s *SqlStore) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) { + return nil, nil +} + +func (s *SqlStore) ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) { + tx := s.db + + var domains []*domain.Domain + result := tx.Find(domains, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get reverse proxy custom domains from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get reverse proxy custom domains from store") + } + + return domains, nil +} + +func (s *SqlStore) CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*domain.Domain, error) { + newDomain := &domain.Domain{ + Domain: domainName, + AccountID: accountID, + Type: domain.TypeCustom, + Validated: validated, + } + result := s.db.Create(newDomain) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create reverse proxy custom domain to store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to create reverse proxy custom domain to store") + } + + return newDomain, nil +} + +func (s *SqlStore) UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error) { + d.AccountID = accountID + result := s.db.Select("*").Save(d) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update reverse proxy custom domain to store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to update reverse proxy custom domain to store") + } + + return d, nil +} + +func (s *SqlStore) DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error { + result := s.db.Delete(domain.Domain{}, accountAndIDQueryCondition, accountID, domainID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete reverse proxy custom domain from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete reverse proxy custom domain from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "reverse proxy custom domain %s not found", domainID) + } + + return nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index c7c2a1c2c..c62349acd 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/telemetry" @@ -248,6 +249,13 @@ type Store interface { GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.ReverseProxy, error) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) + + GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) + ListFreeDomains(ctx context.Context, accountID string) ([]string, error) + ListCustomDomains(ctx context.Context, accountID string) ([]*domain.Domain, error) + CreateCustomDomain(ctx context.Context, accountID string, domainName string, validated bool) (*domain.Domain, error) + UpdateCustomDomain(ctx context.Context, accountID string, d *domain.Domain) (*domain.Domain, error) + DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index a2b5140d4..959173ab8 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -108,6 +108,8 @@ type Account struct { NetworkMapCache *NetworkMapBuilder `gorm:"-"` nmapInitOnce *sync.Once `gorm:"-"` + + ReverseProxyFreeDomainNonce string } func (a *Account) InitOnce() {