diff --git a/management/internals/modules/reverseproxy/domain/manager.go b/management/internals/modules/reverseproxy/domain/manager.go index 21038f0f3..75c522b66 100644 --- a/management/internals/modules/reverseproxy/domain/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager.go @@ -164,15 +164,40 @@ func (m Manager) proxyURLAllowList() []string { for _, addr := range reverseProxyAddresses { proxyUrl, err := url.Parse(addr) if err != nil { - // TODO: log? + log.WithError(err).Debugf("failed to parse proxy URL %s", addr) continue } host, _, err := net.SplitHostPort(proxyUrl.Host) if err != nil { - // TODO: log? host = proxyUrl.Host } allowedProxyURLs = append(allowedProxyURLs, host) } return allowedProxyURLs } + +// DeriveClusterFromDomain determines the proxy cluster for a given domain. +// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. +// For custom domains, the cluster is determined by looking up the CNAME target. +func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) { + allowList := m.proxyURLAllowList() + if len(allowList) == 0 { + return "", fmt.Errorf("no proxy clusters available") + } + + if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok { + return cluster, nil + } + + cluster, valid := m.validator.ValidateWithCluster(ctx, domain, allowList) + if valid { + return cluster, nil + } + + return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) +} + +// GetAvailableClusters returns a list of available proxy cluster addresses. +func (m Manager) GetAvailableClusters() []string { + return m.proxyURLAllowList() +} diff --git a/management/internals/modules/reverseproxy/domain/validator.go b/management/internals/modules/reverseproxy/domain/validator.go index 7c491909b..5616dd45d 100644 --- a/management/internals/modules/reverseproxy/domain/validator.go +++ b/management/internals/modules/reverseproxy/domain/validator.go @@ -32,28 +32,43 @@ func NewValidator(resolver resolver) *Validator { // The comparison is very simple, so wildcards will not match if included // in the acceptable domain list. func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool { + _, valid := v.ValidateWithCluster(ctx, domain, accept) + return valid +} + +// ValidateWithCluster validates a custom domain and returns the matched cluster address. +// Returns the cluster address and true if valid, or empty string and false if invalid. +func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) { if v.resolver == nil { v.resolver = net.DefaultResolver } - // Prepend subdomain for ownership validation because we want to check - // for the record being a wildcard ("*.example.com"), but you cannot - // look up a wildcard so we have to add a subdomain for the check. cname, err := v.resolver.LookupCNAME(ctx, "validation."+domain) if err != nil { log.WithFields(log.Fields{ "domain": domain, }).WithError(err).Error("Error resolving CNAME from resolver") - return false + return "", false } - // Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing). nakedCNAME := strings.TrimSuffix(cname, ".") - for _, domain := range accept { - // Currently, the match is a very simple string comparison. - if nakedCNAME == strings.TrimSuffix(domain, ".") { - return true + for _, acceptDomain := range accept { + normalizedAccept := strings.TrimSuffix(acceptDomain, ".") + if nakedCNAME == normalizedAccept { + return acceptDomain, true } } - return false + return "", false +} + +// ExtractClusterFromFreeDomain extracts the cluster address from a free domain. +// Free domains have the format: .. (e.g., myapp.abc123.eu.proxy.netbird.io) +// It matches the domain suffix against available clusters and returns the matching cluster. +func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) { + for _, cluster := range availableClusters { + if strings.HasSuffix(domain, "."+cluster) { + return cluster, true + } + } + return "", false } diff --git a/management/internals/modules/reverseproxy/manager/api.go b/management/internals/modules/reverseproxy/manager/api.go index efe088d4d..381f296e0 100644 --- a/management/internals/modules/reverseproxy/manager/api.go +++ b/management/internals/modules/reverseproxy/manager/api.go @@ -10,22 +10,29 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) -type handler struct { - manager reverseproxy.Manager +type clusterProvider interface { + GetAvailableClusters() []nbgrpc.ClusterInfo } -func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +type handler struct { + manager reverseproxy.Manager + clusterProvider clusterProvider +} + +// RegisterEndpoints registers all reverse proxy HTTP endpoints. +func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, clusterProvider clusterProvider, router *mux.Router) { h := &handler{ - manager: manager, + manager: manager, + clusterProvider: clusterProvider, } - // Hang domain endpoints off the main router here. domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() domain.RegisterEndpoints(domainRouter, domainManager) @@ -33,6 +40,7 @@ func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manage router.HandleFunc("/reverse-proxies", h.getAllReverseProxies).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies", h.createReverseProxy).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/clusters", h.getAvailableClusters).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/{proxyId}", h.getReverseProxy).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/{proxyId}", h.updateReverseProxy).Methods("PUT", "OPTIONS") router.HandleFunc("/reverse-proxies/{proxyId}", h.deleteReverseProxy).Methods("DELETE", "OPTIONS") @@ -168,3 +176,22 @@ func (h *handler) deleteReverseProxy(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } + +func (h *handler) getAvailableClusters(w http.ResponseWriter, r *http.Request) { + _, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusters := h.clusterProvider.GetAvailableClusters() + apiClusters := make([]api.ProxyCluster, 0, len(clusters)) + for _, c := range clusters { + apiClusters = append(apiClusters, api.ProxyCluster{ + Address: c.Address, + ConnectedProxies: c.ConnectedProxies, + }) + } + + util.WriteJSONObject(r.Context(), w, apiClusters) +} diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 0e9b44983..0bafc80fa 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -20,19 +21,27 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// ClusterDeriver derives the proxy cluster from a domain. +type ClusterDeriver interface { + DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) +} + type managerImpl struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager proxyGRPCServer *nbgrpc.ProxyServiceServer + clusterDeriver ClusterDeriver } -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer) reverseproxy.Manager { +// NewManager creates a new reverse proxy manager. +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { return &managerImpl{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, proxyGRPCServer: proxyGRPCServer, + clusterDeriver: clusterDeriver, } } @@ -69,9 +78,17 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID return nil, status.NewPermissionDeniedError() } + var proxyCluster string + if m.clusterDeriver != nil { + proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxies", reverseProxy.Domain) + } + } + authConfig := reverseProxy.Auth - reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, reverseProxy.Targets, reverseProxy.Enabled) + reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, proxyCluster, reverseProxy.Targets, reverseProxy.Enabled) reverseProxy.Auth = authConfig @@ -146,7 +163,7 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID return nil, fmt.Errorf("failed to create setup key for reverse proxy: %w", err) } - m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Create, key.Key)) + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, key.Key), reverseProxy.ProxyCluster) return reverseProxy, nil } @@ -160,28 +177,44 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID return nil, status.NewPermissionDeniedError() } + var oldCluster string + var domainChanged bool + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - // Get existing reverse proxy existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID) if err != nil { return err } - // Check if domain changed and if it conflicts + oldCluster = existingReverseProxy.ProxyCluster + if existingReverseProxy.Domain != reverseProxy.Domain { + domainChanged = true conflictReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { - return fmt.Errorf("failed to check existing reverse proxy: %w", err) + return fmt.Errorf("check existing reverse proxy: %w", err) } } if conflictReverseProxy != nil && conflictReverseProxy.ID != reverseProxy.ID { return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain) } + + if m.clusterDeriver != nil { + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s", reverseProxy.Domain) + } + reverseProxy.ProxyCluster = newCluster + } + } else { + reverseProxy.ProxyCluster = existingReverseProxy.ProxyCluster } + reverseProxy.Meta = existingReverseProxy.Meta + if err = transaction.UpdateReverseProxy(ctx, reverseProxy); err != nil { - return fmt.Errorf("failed to update reverse proxy: %w", err) + return fmt.Errorf("update reverse proxy: %w", err) } return nil @@ -192,7 +225,12 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta()) - m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Update, "")) + if domainChanged && oldCluster != reverseProxy.ProxyCluster { + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, ""), oldCluster) + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, ""), reverseProxy.ProxyCluster) + } else { + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, ""), reverseProxy.ProxyCluster) + } return reverseProxy, nil } @@ -226,7 +264,7 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID, m.accountManager.StoreEvent(ctx, userID, reverseProxyID, accountID, activity.ReverseProxyDeleted, reverseProxy.EventMeta()) - m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Delete, "")) + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, ""), reverseProxy.ProxyCluster) return nil } diff --git a/management/internals/modules/reverseproxy/reverseproxy.go b/management/internals/modules/reverseproxy/reverseproxy.go index a7aebe938..edb541410 100644 --- a/management/internals/modules/reverseproxy/reverseproxy.go +++ b/management/internals/modules/reverseproxy/reverseproxy.go @@ -76,24 +76,26 @@ type ReverseProxyMeta struct { } type ReverseProxy struct { - ID string `gorm:"primaryKey"` - AccountID string `gorm:"index"` - Name string - Domain string `gorm:"index"` - Targets []Target `gorm:"serializer:json"` - Enabled bool - Auth AuthConfig `gorm:"serializer:json"` - Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"` + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Domain string `gorm:"index"` + ProxyCluster string `gorm:"index"` + Targets []Target `gorm:"serializer:json"` + Enabled bool + Auth AuthConfig `gorm:"serializer:json"` + Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"` } -func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy { +func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []Target, enabled bool) *ReverseProxy { return &ReverseProxy{ - ID: xid.New().String(), - AccountID: accountID, - Name: name, - Domain: domain, - Targets: targets, - Enabled: enabled, + ID: xid.New().String(), + AccountID: accountID, + Name: name, + Domain: domain, + ProxyCluster: proxyCluster, + Targets: targets, + Enabled: enabled, Meta: ReverseProxyMeta{ CreatedAt: time.Now(), Status: string(StatusPending), @@ -154,7 +156,7 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy { meta.CertificateIssuedAt = &r.Meta.CertificateIssuedAt } - return &api.ReverseProxy{ + resp := &api.ReverseProxy{ Id: r.ID, Name: r.Name, Domain: r.Domain, @@ -163,6 +165,12 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy { Auth: authConfig, Meta: meta, } + + if r.ProxyCluster != "" { + resp.ProxyCluster = &r.ProxyCluster + } + + return resp } func (r *ReverseProxy) ToProtoMapping(operation Operation, setupKey string) *proto.ProxyMapping { @@ -310,5 +318,5 @@ func (r *ReverseProxy) Validate() error { } func (r *ReverseProxy) EventMeta() map[string]any { - return map[string]any{"name": r.Name, "domain": r.Domain} + return map[string]any{"name": r.Name, "domain": r.Domain, "proxy_cluster": r.ProxyCluster} } diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index de86ad22a..3b25ee7ae 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 0d1b44be2..02a4330f7 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -180,12 +180,14 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { return Create(s, func() reverseproxy.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer()) + domainMgr := s.ReverseProxyDomainManager() + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), domainMgr) }) } -func (s *BaseServer) ReverseProxyDomainManager() domain.Manager { - return Create(s, func() domain.Manager { - return domain.NewManager(s.Store(), s.ReverseProxyGRPCServer()) +func (s *BaseServer) ReverseProxyDomainManager() *domain.Manager { + return Create(s, func() *domain.Manager { + m := domain.NewManager(s.Store(), s.ReverseProxyGRPCServer()) + return &m }) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 10bc14ac3..a97860528 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -4,6 +4,8 @@ import ( "context" "crypto/subtle" "fmt" + "net" + "net/url" "sync" "time" @@ -37,6 +39,12 @@ type keyStore interface { CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) } +// ClusterInfo contains information about a proxy cluster. +type ClusterInfo struct { + Address string + ConnectedProxies int +} + // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -44,6 +52,9 @@ type ProxyServiceServer struct { // Map of connected proxies: proxy_id -> proxy connection connectedProxies sync.Map + // Map of cluster address -> set of proxy IDs + clusterProxies sync.Map + // Channel for broadcasting reverse proxy updates to all proxies updatesChan chan *proto.ProxyMapping @@ -115,8 +126,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } s.connectedProxies.Store(proxyID, conn) + s.addToCluster(conn.address, proxyID) defer func() { s.connectedProxies.Delete(proxyID) + s.removeFromCluster(conn.address, proxyID) cancel() log.Infof("Proxy %s disconnected", proxyID) }() @@ -137,17 +150,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// sendSnapshot sends the initial snapshot of all reverse proxies to proxy +// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy. +// Only reverse proxies matching the proxy's cluster address are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength. + reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) if err != nil { - // TODO: something? - return fmt.Errorf("get account reverse proxies from store: %w", err) + return fmt.Errorf("get reverse proxies from store: %w", err) } + proxyClusterAddr := extractClusterAddr(conn.address) + for _, rp := range reverseProxies { if !rp.Enabled { - // We don't care about disabled reverse proxies for snapshots. + continue + } + + if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr { continue } @@ -160,7 +178,6 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec continue } - // TODO: should this even be here? We're running in a loop, and on each proxy, this will create a LOT of setup key entries that we currently have no way to remove. key, err := s.keyStore.CreateSetupKey(ctx, rp.AccountID, rp.Name, @@ -184,7 +201,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ Mapping: []*proto.ProxyMapping{ rp.ToProtoMapping( - reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy. + reverseproxy.Create, key.Key, ), }, @@ -197,6 +214,22 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } +// extractClusterAddr extracts the host from a proxy address URL. +func extractClusterAddr(addr string) string { + if addr == "" { + return "" + } + u, err := url.Parse(addr) + if err != nil { + return addr + } + host := u.Host + if h, _, err := net.SplitHostPort(host); err == nil { + return h + } + return host +} + // sender handles sending messages to proxy func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { @@ -284,6 +317,84 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string { return urls } +// addToCluster registers a proxy in a cluster. +func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) { + if clusterAddr == "" { + return + } + proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) + proxySet.(*sync.Map).Store(proxyID, struct{}{}) + log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr) +} + +// removeFromCluster removes a proxy from a cluster. +func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { + if clusterAddr == "" { + return + } + if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok { + proxySet.(*sync.Map).Delete(proxyID) + log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr) + } +} + +// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster. +// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility). +func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { + if clusterAddr == "" { + s.SendReverseProxyUpdate(update) + return + } + + proxySet, ok := s.clusterProxies.Load(clusterAddr) + if !ok { + log.Debugf("No proxies connected for cluster %s", clusterAddr) + return + } + + log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr) + proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { + proxyID := key.(string) + if connVal, ok := s.connectedProxies.Load(proxyID); ok { + conn := connVal.(*proxyConnection) + select { + case conn.sendChan <- update: + log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + default: + log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + } + } + return true + }) +} + +// GetAvailableClusters returns information about all connected proxy clusters. +func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo { + clusterCounts := make(map[string]int) + s.clusterProxies.Range(func(key, value interface{}) bool { + clusterAddr := key.(string) + proxySet := value.(*sync.Map) + count := 0 + proxySet.Range(func(_, _ interface{}) bool { + count++ + return true + }) + if count > 0 { + clusterCounts[clusterAddr] = count + } + return true + }) + + clusters := make([]ClusterInfo, 0, len(clusterCounts)) + for addr, count := range clusterCounts { + clusters = append(clusters, ClusterInfo{ + Address: addr, + ConnectedProxies: count, + }) + } + return clusters +} + func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId()) if err != nil { diff --git a/management/server/account.go b/management/server/account.go index ba5f0cffa..748290a31 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,7 +26,6 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -49,6 +48,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 2aed06fa4..92f349160 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" idpmanager "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/management-integrations/integrations" @@ -63,8 +64,13 @@ const ( rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" ) +// clusterProvider provides access to available proxy clusters. +type clusterProvider interface { + GetAvailableClusters() []nbgrpc.ClusterInfo +} + // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager domain.Manager, reverseProxyAccessLogsManager accesslogs.Manager) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *domain.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyClusterProvider clusterProvider) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -160,7 +166,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) - reverseproxymanager.RegisterEndpoints(reverseProxyManager, reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + if reverseProxyManager != nil && reverseProxyDomainManager != nil { + reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, proxyClusterProvider, router) + } // Mount embedded IdP handler at /oauth2 path if configured if embeddedIdpEnabled { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 9339c3541..bdc68e85f 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -102,7 +102,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 0d4461e89..7d3837190 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -135,7 +135,7 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 - httpClient := &http.Client{ + httpClient := &http.Client{ Timeout: idpTimeout(), Transport: httpTransport, } diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 0f30cc63d..ebd79b715 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -56,7 +56,7 @@ func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppM Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ClientID == "" { diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index e098424b5..320ca7a83 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -57,11 +57,11 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 - httpClient := &http.Client{ + httpClient := &http.Client{ Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ClientID == "" { diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 6e417d394..48e4f3000 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -51,7 +51,7 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.CustomerID == "" { diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index b640f7520..1cf26394f 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -66,7 +66,7 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ClientID == "" { diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go index ee8e304ee..fc338b86b 100644 --- a/management/server/idp/pocketid.go +++ b/management/server/idp/pocketid.go @@ -90,7 +90,7 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ManagementEndpoint == "" { diff --git a/management/server/idp/util.go b/management/server/idp/util.go index 4310d1388..ed82fb9e3 100644 --- a/management/server/idp/util.go +++ b/management/server/idp/util.go @@ -76,7 +76,7 @@ const ( // Provides the env variable name for use with idpTimeout function idpTimeoutEnv = "NB_IDP_TIMEOUT" // Sets the defaultTimeout to 10s. - defaultTimeout = 10 * time.Second + defaultTimeout = 10 * time.Second ) // idpTimeout returns a timeout value for the IDP diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index ea0fd0aa7..320f0c131 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -167,7 +167,7 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} hasPAT := config.PAT != "" diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go index 1e75330eb..b7ff490d6 100644 --- a/management/server/permissions/modules/module.go +++ b/management/server/permissions/modules/module.go @@ -3,38 +3,38 @@ package modules type Module string const ( - Networks Module = "networks" - Peers Module = "peers" - RemoteJobs Module = "remote_jobs" - Groups Module = "groups" - Settings Module = "settings" - Accounts Module = "accounts" - Dns Module = "dns" - Nameservers Module = "nameservers" - Events Module = "events" - Policies Module = "policies" - Routes Module = "routes" - Users Module = "users" - SetupKeys Module = "setup_keys" - Pats Module = "pats" + Networks Module = "networks" + Peers Module = "peers" + RemoteJobs Module = "remote_jobs" + Groups Module = "groups" + Settings Module = "settings" + Accounts Module = "accounts" + Dns Module = "dns" + Nameservers Module = "nameservers" + Events Module = "events" + Policies Module = "policies" + Routes Module = "routes" + Users Module = "users" + SetupKeys Module = "setup_keys" + Pats Module = "pats" IdentityProviders Module = "identity_providers" Services Module = "services" ) var All = map[Module]struct{}{ - Networks: {}, - Peers: {}, - RemoteJobs: {}, - Groups: {}, - Settings: {}, - Accounts: {}, - Dns: {}, - Nameservers: {}, - Events: {}, - Policies: {}, - Routes: {}, - Users: {}, - SetupKeys: {}, - Pats: {}, + Networks: {}, + Peers: {}, + RemoteJobs: {}, + Groups: {}, + Settings: {}, + Accounts: {}, + Dns: {}, + Nameservers: {}, + Events: {}, + Policies: {}, + Routes: {}, + Users: {}, + SetupKeys: {}, + Pats: {}, IdentityProviders: {}, } diff --git a/management/server/util/util.go b/management/server/util/util.go index ce9759864..617484274 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -50,4 +50,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool { } return false } - diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 9fb455b7a..c43e7351c 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2845,6 +2845,10 @@ components: domain: type: string description: Domain for the reverse proxy + proxy_cluster: + type: string + description: The proxy cluster handling this reverse proxy (derived from domain) + example: "eu.proxy.netbird.io" targets: type: array items: @@ -3007,6 +3011,21 @@ components: description: Whether link auth is enabled required: - enabled + ProxyCluster: + type: object + description: A proxy cluster represents a group of proxy nodes serving the same address + properties: + address: + type: string + description: Cluster address used for CNAME targets + example: "eu.proxy.netbird.io" + connected_proxies: + type: integer + description: Number of proxy nodes connected in this cluster + example: 3 + required: + - address + - connected_proxies ReverseProxyDomainType: type: string description: Type of Reverse Proxy Domain @@ -6702,6 +6721,29 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/clusters: + get: + summary: List available proxy clusters + description: Returns a list of available proxy clusters with their connection status + tags: [ Reverse Proxy ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy clusters + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyCluster' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/{proxyId}: get: summary: Retrieve a Reverse Proxy diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 91253de9b..0d56bcca4 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1947,6 +1947,15 @@ type ProxyAccessLog struct { UserId *string `json:"user_id,omitempty"` } +// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address +type ProxyCluster struct { + // Address Cluster address used for CNAME targets + Address string `json:"address"` + + // ConnectedProxies Number of proxy nodes connected in this cluster + ConnectedProxies int `json:"connected_proxies"` +} + // Resource defines model for Resource. type Resource struct { // Id ID of the resource @@ -1974,6 +1983,9 @@ type ReverseProxy struct { // Name Reverse proxy name Name string `json:"name"` + // ProxyCluster The proxy cluster handling this reverse proxy (derived from domain) + ProxyCluster *string `json:"proxy_cluster,omitempty"` + // Targets List of target backends for this reverse proxy Targets []ReverseProxyTarget `json:"targets"` }