diff --git a/management/internals/modules/reverseproxy/domain/manager.go b/management/internals/modules/reverseproxy/domain/manager.go index 21038f0f3..8d0498706 100644 --- a/management/internals/modules/reverseproxy/domain/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/url" + "strings" "github.com/netbirdio/netbird/management/server/types" log "github.com/sirupsen/logrus" @@ -57,21 +58,31 @@ func NewManager(store store, proxyURLProvider proxyURLProvider) Manager { } func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, error) { - account, err := m.store.GetAccount(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("get account: %w", err) - } - free, err := m.store.ListFreeDomains(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("list free domains: %w", err) - } domains, err := m.store.ListCustomDomains(ctx, accountID) if err != nil { return nil, fmt.Errorf("list custom domains: %w", err) } var ret []*Domain - // Populate all fields correctly for custom domains that are retrieved. + + // Add connected proxy clusters as free domains. + // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). + allowList := m.proxyURLAllowList() + log.WithFields(log.Fields{ + "accountID": accountID, + "proxyAllowList": allowList, + }).Debug("getting domains with proxy allow list") + + for _, cluster := range allowList { + ret = append(ret, &Domain{ + Domain: cluster, + AccountID: accountID, + Type: TypeFree, + Validated: true, + }) + } + + // Add custom domains. for _, domain := range domains { ret = append(ret, &Domain{ ID: domain.ID, @@ -82,19 +93,6 @@ func (m Manager) GetDomains(ctx context.Context, accountID string) ([]*Domain, e }) } - // Prepend each free domain with the account nonce and then add it to the domain - // array to be returned. - // This account nonce is added to free domains to prevent users being able to - // query free domain usage across accounts and simplifies tracking free domain - // usage across accounts. - for _, name := range free { - ret = append(ret, &Domain{ - Domain: account.ReverseProxyFreeDomainNonce + "." + name, - AccountID: accountID, - Type: TypeFree, - Validated: true, - }) - } return ret, nil } @@ -124,6 +122,11 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, domainID string) e } func (m Manager) ValidateDomain(accountID, domainID string) { + log.WithFields(log.Fields{ + "accountID": accountID, + "domainID": domainID, + }).Info("starting domain validation") + d, err := m.store.GetCustomDomain(context.Background(), accountID, domainID) if err != nil { log.WithFields(log.Fields{ @@ -133,12 +136,20 @@ func (m Manager) ValidateDomain(accountID, domainID string) { return } - if m.validator.IsValid(context.Background(), d.Domain, m.proxyURLAllowList()) { + allowList := m.proxyURLAllowList() + log.WithFields(log.Fields{ + "accountID": accountID, + "domainID": domainID, + "domain": d.Domain, + "proxyAllowList": allowList, + }).Info("validating domain against proxy allow list") + + if m.validator.IsValid(context.Background(), d.Domain, allowList) { log.WithFields(log.Fields{ "accountID": accountID, "domainID": domainID, "domain": d.Domain, - }).Debug("domain validated successfully") + }).Info("domain validated successfully") d.Validated = true if _, err := m.store.UpdateCustomDomain(context.Background(), accountID, d); err != nil { log.WithFields(log.Fields{ @@ -148,6 +159,13 @@ func (m Manager) ValidateDomain(accountID, domainID string) { }).WithError(err).Error("update custom domain in store") return } + } else { + log.WithFields(log.Fields{ + "accountID": accountID, + "domainID": domainID, + "domain": d.Domain, + "proxyAllowList": allowList, + }).Warn("domain validation failed - CNAME does not match any connected proxy") } } @@ -162,17 +180,60 @@ func (m Manager) proxyURLAllowList() []string { } var allowedProxyURLs []string for _, addr := range reverseProxyAddresses { - proxyUrl, err := url.Parse(addr) - if err != nil { - // TODO: log? + if addr == "" { continue } - host, _, err := net.SplitHostPort(proxyUrl.Host) - if err != nil { - // TODO: log? - host = proxyUrl.Host + host := extractHostFromAddress(addr) + if host != "" { + allowedProxyURLs = append(allowedProxyURLs, host) } - allowedProxyURLs = append(allowedProxyURLs, host) } return allowedProxyURLs } + +// extractHostFromAddress extracts the hostname from an address string. +// It handles both URL format (https://host:port) and plain hostname (host or host:port). +func extractHostFromAddress(addr string) string { + // If it looks like a URL with a scheme, parse it + if strings.Contains(addr, "://") { + proxyUrl, err := url.Parse(addr) + if err != nil { + log.WithError(err).Debugf("failed to parse proxy URL %s", addr) + return "" + } + host, _, err := net.SplitHostPort(proxyUrl.Host) + if err != nil { + return proxyUrl.Host + } + return host + } + + // Otherwise treat as hostname or host:port + host, _, err := net.SplitHostPort(addr) + if err != nil { + // No port, use as-is + return addr + } + return host +} + +// DeriveClusterFromDomain determines the proxy cluster for a given domain. +// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. +// For custom domains, the cluster is determined by looking up the CNAME target. +func (m Manager) DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) { + allowList := m.proxyURLAllowList() + if len(allowList) == 0 { + return "", fmt.Errorf("no proxy clusters available") + } + + if cluster, ok := ExtractClusterFromFreeDomain(domain, allowList); ok { + return cluster, nil + } + + cluster, valid := m.validator.ValidateWithCluster(ctx, domain, allowList) + if valid { + return cluster, nil + } + + return "", fmt.Errorf("domain %s does not match any available proxy cluster", domain) +} diff --git a/management/internals/modules/reverseproxy/domain/validator.go b/management/internals/modules/reverseproxy/domain/validator.go index 7c491909b..bb9c09035 100644 --- a/management/internals/modules/reverseproxy/domain/validator.go +++ b/management/internals/modules/reverseproxy/domain/validator.go @@ -32,28 +32,69 @@ func NewValidator(resolver resolver) *Validator { // The comparison is very simple, so wildcards will not match if included // in the acceptable domain list. func (v *Validator) IsValid(ctx context.Context, domain string, accept []string) bool { + _, valid := v.ValidateWithCluster(ctx, domain, accept) + return valid +} + +// ValidateWithCluster validates a custom domain and returns the matched cluster address. +// Returns the cluster address and true if valid, or empty string and false if invalid. +func (v *Validator) ValidateWithCluster(ctx context.Context, domain string, accept []string) (string, bool) { if v.resolver == nil { v.resolver = net.DefaultResolver } - // Prepend subdomain for ownership validation because we want to check - // for the record being a wildcard ("*.example.com"), but you cannot - // look up a wildcard so we have to add a subdomain for the check. - cname, err := v.resolver.LookupCNAME(ctx, "validation."+domain) + lookupDomain := "validation." + domain + log.WithFields(log.Fields{ + "domain": domain, + "lookupDomain": lookupDomain, + "acceptList": accept, + }).Debug("looking up CNAME for domain validation") + + cname, err := v.resolver.LookupCNAME(ctx, lookupDomain) if err != nil { log.WithFields(log.Fields{ - "domain": domain, - }).WithError(err).Error("Error resolving CNAME from resolver") - return false + "domain": domain, + "lookupDomain": lookupDomain, + }).WithError(err).Warn("CNAME lookup failed for domain validation") + return "", false } - // Remove a trailing "." from the CNAME (most people do not include the trailing "." in FQDN, so it is easier to strip this when comparing). nakedCNAME := strings.TrimSuffix(cname, ".") - for _, domain := range accept { - // Currently, the match is a very simple string comparison. - if nakedCNAME == strings.TrimSuffix(domain, ".") { - return true + log.WithFields(log.Fields{ + "domain": domain, + "cname": cname, + "nakedCNAME": nakedCNAME, + "acceptList": accept, + }).Debug("CNAME lookup result for domain validation") + + for _, acceptDomain := range accept { + normalizedAccept := strings.TrimSuffix(acceptDomain, ".") + if nakedCNAME == normalizedAccept { + log.WithFields(log.Fields{ + "domain": domain, + "cname": nakedCNAME, + "cluster": acceptDomain, + }).Info("domain CNAME matched cluster") + return acceptDomain, true } } - return false + + log.WithFields(log.Fields{ + "domain": domain, + "cname": nakedCNAME, + "acceptList": accept, + }).Warn("domain CNAME does not match any accepted cluster") + return "", false +} + +// ExtractClusterFromFreeDomain extracts the cluster address from a free domain. +// Free domains have the format: .. (e.g., myapp.abc123.eu.proxy.netbird.io) +// It matches the domain suffix against available clusters and returns the matching cluster. +func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) { + for _, cluster := range availableClusters { + if strings.HasSuffix(domain, "."+cluster) { + return cluster, true + } + } + return "", false } diff --git a/management/internals/modules/reverseproxy/manager/api.go b/management/internals/modules/reverseproxy/manager/api.go index efe088d4d..990f775b4 100644 --- a/management/internals/modules/reverseproxy/manager/api.go +++ b/management/internals/modules/reverseproxy/manager/api.go @@ -20,12 +20,12 @@ type handler struct { manager reverseproxy.Manager } +// RegisterEndpoints registers all reverse proxy HTTP endpoints. func RegisterEndpoints(manager reverseproxy.Manager, domainManager domain.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { h := &handler{ manager: manager, } - // Hang domain endpoints off the main router here. domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() domain.RegisterEndpoints(domainRouter, domainManager) diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 6174eb2fe..cc37ef8fc 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -5,6 +5,10 @@ import ( "fmt" "time" + "github.com/google/uuid" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -17,21 +21,29 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// ClusterDeriver derives the proxy cluster from a domain. +type ClusterDeriver interface { + DeriveClusterFromDomain(ctx context.Context, domain string) (string, error) +} + type managerImpl struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager proxyGRPCServer *nbgrpc.ProxyServiceServer tokenStore *nbgrpc.OneTimeTokenStore + clusterDeriver ClusterDeriver } -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore) reverseproxy.Manager { +// NewManager creates a new reverse proxy manager. +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, tokenStore *nbgrpc.OneTimeTokenStore, clusterDeriver ClusterDeriver) reverseproxy.Manager { return &managerImpl{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, proxyGRPCServer: proxyGRPCServer, tokenStore: tokenStore, + clusterDeriver: clusterDeriver, } } @@ -68,9 +80,17 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID return nil, status.NewPermissionDeniedError() } + var proxyCluster string + if m.clusterDeriver != nil { + proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxies", reverseProxy.Domain) + } + } + authConfig := reverseProxy.Auth - reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, reverseProxy.Targets, reverseProxy.Enabled) + reverseProxy = reverseproxy.NewReverseProxy(accountID, reverseProxy.Name, reverseProxy.Domain, proxyCluster, reverseProxy.Targets, reverseProxy.Enabled) reverseProxy.Auth = authConfig @@ -111,7 +131,7 @@ func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyCreated, reverseProxy.EventMeta()) - m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig())) + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token), m.proxyGRPCServer.GetOIDCValidationConfig(), reverseProxy.ProxyCluster)) return reverseProxy, nil } @@ -125,28 +145,44 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID return nil, status.NewPermissionDeniedError() } + var oldCluster string + var domainChanged bool + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - // Get existing reverse proxy existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID) if err != nil { return err } - // Check if domain changed and if it conflicts + oldCluster = existingReverseProxy.ProxyCluster + if existingReverseProxy.Domain != reverseProxy.Domain { + domainChanged = true conflictReverseProxy, err := transaction.GetReverseProxyByDomain(ctx, accountID, reverseProxy.Domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { - return fmt.Errorf("failed to check existing reverse proxy: %w", err) + return fmt.Errorf("check existing reverse proxy: %w", err) } } if conflictReverseProxy != nil && conflictReverseProxy.ID != reverseProxy.ID { return status.Errorf(status.AlreadyExists, "reverse proxy with domain %s already exists", reverseProxy.Domain) } + + if m.clusterDeriver != nil { + newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, reverseProxy.Domain) + if err != nil { + log.WithError(err).Warnf("could not derive cluster from domain %s", reverseProxy.Domain) + } + reverseProxy.ProxyCluster = newCluster + } + } else { + reverseProxy.ProxyCluster = existingReverseProxy.ProxyCluster } + reverseProxy.Meta = existingReverseProxy.Meta + if err = transaction.UpdateReverseProxy(ctx, reverseProxy); err != nil { - return fmt.Errorf("failed to update reverse proxy: %w", err) + return fmt.Errorf("update reverse proxy: %w", err) } return nil @@ -157,7 +193,12 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta()) - m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig())) + if domainChanged && oldCluster != reverseProxy.ProxyCluster { + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, ""), oldCluster, m.proxyGRPCServer.GetOIDCValidationConfig()) + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, ""), reverseProxy.ProxyCluster, m.proxyGRPCServer.GetOIDCValidationConfig()) + } else { + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, ""), reverseProxy.ProxyCluster, m.proxyGRPCServer.GetOIDCValidationConfig()) + } return reverseProxy, nil } @@ -191,7 +232,7 @@ func (m *managerImpl) DeleteReverseProxy(ctx context.Context, accountID, userID, m.accountManager.StoreEvent(ctx, userID, reverseProxyID, accountID, activity.ReverseProxyDeleted, reverseProxy.EventMeta()) - m.proxyGRPCServer.SendReverseProxyUpdate(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig())) + m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, ""), reverseProxy.ProxyCluster, m.proxyGRPCServer.GetOIDCValidationConfig()) return nil } diff --git a/management/internals/modules/reverseproxy/reverseproxy.go b/management/internals/modules/reverseproxy/reverseproxy.go index 8f5323057..334255ba5 100644 --- a/management/internals/modules/reverseproxy/reverseproxy.go +++ b/management/internals/modules/reverseproxy/reverseproxy.go @@ -82,26 +82,28 @@ type ReverseProxyMeta struct { } type ReverseProxy struct { - ID string `gorm:"primaryKey"` - AccountID string `gorm:"index"` - Name string - Domain string `gorm:"index"` - Targets []Target `gorm:"serializer:json"` - Enabled bool - Auth AuthConfig `gorm:"serializer:json"` - Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"` + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Domain string `gorm:"index"` + ProxyCluster string `gorm:"index"` + Targets []Target `gorm:"serializer:json"` + Enabled bool + Auth AuthConfig `gorm:"serializer:json"` + Meta ReverseProxyMeta `gorm:"embedded;embeddedPrefix:meta_"` SessionPrivateKey string `gorm:"column:session_private_key"` SessionPublicKey string `gorm:"column:session_public_key"` } -func NewReverseProxy(accountID, name, domain string, targets []Target, enabled bool) *ReverseProxy { +func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []Target, enabled bool) *ReverseProxy { return &ReverseProxy{ - ID: xid.New().String(), - AccountID: accountID, - Name: name, - Domain: domain, - Targets: targets, - Enabled: enabled, + ID: xid.New().String(), + AccountID: accountID, + Name: name, + Domain: domain, + ProxyCluster: proxyCluster, + Targets: targets, + Enabled: enabled, Meta: ReverseProxyMeta{ CreatedAt: time.Now(), Status: string(StatusPending), @@ -156,7 +158,7 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy { meta.CertificateIssuedAt = &r.Meta.CertificateIssuedAt } - return &api.ReverseProxy{ + resp := &api.ReverseProxy{ Id: r.ID, Name: r.Name, Domain: r.Domain, @@ -165,6 +167,12 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy { Auth: authConfig, Meta: meta, } + + if r.ProxyCluster != "" { + resp.ProxyCluster = &r.ProxyCluster + } + + return resp } func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping { @@ -302,7 +310,7 @@ func (r *ReverseProxy) Validate() error { } func (r *ReverseProxy) EventMeta() map[string]any { - return map[string]any{"name": r.Name, "domain": r.Domain} + return map[string]any{"name": r.Name, "domain": r.Domain, "proxy_cluster": r.ProxyCluster} } func (r *ReverseProxy) Copy() *ReverseProxy { diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index f2ff1f767..e413439a0 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -180,12 +180,13 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { return Create(s, func() reverseproxy.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ProxyTokenStore()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ProxyTokenStore(), s.ReverseProxyDomainManager()) }) } -func (s *BaseServer) ReverseProxyDomainManager() domain.Manager { - return Create(s, func() domain.Manager { - return domain.NewManager(s.Store(), s.ReverseProxyGRPCServer()) +func (s *BaseServer) ReverseProxyDomainManager() *domain.Manager { + return Create(s, func() *domain.Manager { + m := domain.NewManager(s.Store(), s.ReverseProxyGRPCServer()) + return &m }) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index 52c246235..8bc149e59 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -9,6 +9,7 @@ import ( "encoding/hex" "errors" "fmt" + "net" "net/url" "strings" "sync" @@ -52,6 +53,12 @@ type reverseProxyManager interface { SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error } +// ClusterInfo contains information about a proxy cluster. +type ClusterInfo struct { + Address string + ConnectedProxies int +} + // ProxyServiceServer implements the ProxyService gRPC server type ProxyServiceServer struct { proto.UnimplementedProxyServiceServer @@ -59,6 +66,9 @@ type ProxyServiceServer struct { // Map of connected proxies: proxy_id -> proxy connection connectedProxies sync.Map + // Map of cluster address -> set of proxy IDs + clusterProxies sync.Map + // Channel for broadcasting reverse proxy updates to all proxies updatesChan chan *proto.ProxyMapping @@ -127,13 +137,18 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest return status.Errorf(codes.InvalidArgument, "proxy_id is required") } - log.Infof("Proxy %s connected (version: %s, started: %s)", - proxyID, req.GetVersion(), req.GetStartedAt().AsTime()) + proxyAddress := req.GetAddress() + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "address": proxyAddress, + "version": req.GetVersion(), + "started": req.GetStartedAt().AsTime(), + }).Info("Proxy connected") connCtx, cancel := context.WithCancel(ctx) conn := &proxyConnection{ proxyID: proxyID, - address: req.GetAddress(), + address: proxyAddress, stream: stream, sendChan: make(chan *proto.ProxyMapping, 100), ctx: connCtx, @@ -141,8 +156,16 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } s.connectedProxies.Store(proxyID, conn) + s.addToCluster(conn.address, proxyID) + log.WithFields(log.Fields{ + "proxy_id": proxyID, + "address": proxyAddress, + "cluster_addr": extractClusterAddr(proxyAddress), + "total_proxies": len(s.GetConnectedProxies()), + }).Info("Proxy registered in cluster") defer func() { s.connectedProxies.Delete(proxyID) + s.removeFromCluster(conn.address, proxyID) cancel() log.Infof("Proxy %s disconnected", proxyID) }() @@ -163,17 +186,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } -// sendSnapshot sends the initial snapshot of all reverse proxies to proxy +// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy. +// Only reverse proxies matching the proxy's cluster address are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength. + reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) if err != nil { - // TODO: something? - return fmt.Errorf("get account reverse proxies from store: %w", err) + return fmt.Errorf("get reverse proxies from store: %w", err) } + proxyClusterAddr := extractClusterAddr(conn.address) + for _, rp := range reverseProxies { if !rp.Enabled { - // We don't care about disabled reverse proxies for snapshots. + continue + } + + if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr { continue } @@ -205,6 +233,22 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return nil } +// extractClusterAddr extracts the host from a proxy address URL. +func extractClusterAddr(addr string) string { + if addr == "" { + return "" + } + u, err := url.Parse(addr) + if err != nil { + return addr + } + host := u.Host + if h, _, err := net.SplitHostPort(host); err == nil { + return h + } + return host +} + // sender handles sending messages to proxy func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) { for { @@ -281,17 +325,106 @@ func (s *ProxyServiceServer) GetConnectedProxies() []string { func (s *ProxyServiceServer) GetConnectedProxyURLs() []string { seenUrls := make(map[string]struct{}) var urls []string + var proxyCount int s.connectedProxies.Range(func(key, value interface{}) bool { + proxyCount++ conn := value.(*proxyConnection) + log.WithFields(log.Fields{ + "proxy_id": conn.proxyID, + "address": conn.address, + }).Debug("checking connected proxy for URL") if _, seen := seenUrls[conn.address]; conn.address != "" && !seen { seenUrls[conn.address] = struct{}{} urls = append(urls, conn.address) } return true }) + log.WithFields(log.Fields{ + "total_proxies": proxyCount, + "unique_urls": len(urls), + "connected_urls": urls, + }).Debug("GetConnectedProxyURLs result") return urls } +// addToCluster registers a proxy in a cluster. +func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) { + if clusterAddr == "" { + return + } + proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) + proxySet.(*sync.Map).Store(proxyID, struct{}{}) + log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr) +} + +// removeFromCluster removes a proxy from a cluster. +func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { + if clusterAddr == "" { + return + } + if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok { + proxySet.(*sync.Map).Delete(proxyID) + log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr) + } +} + +// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster. +// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility). +func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { + if clusterAddr == "" { + s.SendReverseProxyUpdate(update) + return + } + + proxySet, ok := s.clusterProxies.Load(clusterAddr) + if !ok { + log.Debugf("No proxies connected for cluster %s", clusterAddr) + return + } + + log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr) + proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { + proxyID := key.(string) + if connVal, ok := s.connectedProxies.Load(proxyID); ok { + conn := connVal.(*proxyConnection) + select { + case conn.sendChan <- update: + log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + default: + log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + } + } + return true + }) +} + +// GetAvailableClusters returns information about all connected proxy clusters. +func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo { + clusterCounts := make(map[string]int) + s.clusterProxies.Range(func(key, value interface{}) bool { + clusterAddr := key.(string) + proxySet := value.(*sync.Map) + count := 0 + proxySet.Range(func(_, _ interface{}) bool { + count++ + return true + }) + if count > 0 { + clusterCounts[clusterAddr] = count + } + return true + }) + + clusters := make([]ClusterInfo, 0, len(clusterCounts)) + for addr, count := range clusterCounts { + clusters = append(clusters, ClusterInfo{ + Address: addr, + ConnectedProxies: count, + }) + } + return clusters +} + func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId()) if err != nil { diff --git a/management/server/account.go b/management/server/account.go index ba5f0cffa..748290a31 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,7 +26,6 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -49,6 +48,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index a8fa06383..9d8dc3fe2 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -66,7 +66,7 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager domain.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *domain.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -166,7 +166,9 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) - reverseproxymanager.RegisterEndpoints(reverseProxyManager, reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + if reverseProxyManager != nil && reverseProxyDomainManager != nil { + reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + } // Register OAuth callback handler for proxy authentication if proxyGRPCServer != nil { diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index e976100cd..ce5b31cdf 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee reverseProxyDomainManager := domain.NewManager(store, nil) var accessLogsManager accesslogs.Manager - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, reverseProxyDomainManager, accessLogsManager) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, reverseProxyDomainManager, accessLogsManager, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/idp/auth0.go b/management/server/idp/auth0.go index 0d4461e89..7d3837190 100644 --- a/management/server/idp/auth0.go +++ b/management/server/idp/auth0.go @@ -135,7 +135,7 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics) httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 - httpClient := &http.Client{ + httpClient := &http.Client{ Timeout: idpTimeout(), Transport: httpTransport, } diff --git a/management/server/idp/authentik.go b/management/server/idp/authentik.go index 0f30cc63d..ebd79b715 100644 --- a/management/server/idp/authentik.go +++ b/management/server/idp/authentik.go @@ -56,7 +56,7 @@ func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppM Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ClientID == "" { diff --git a/management/server/idp/azure.go b/management/server/idp/azure.go index e098424b5..320ca7a83 100644 --- a/management/server/idp/azure.go +++ b/management/server/idp/azure.go @@ -57,11 +57,11 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics) httpTransport := http.DefaultTransport.(*http.Transport).Clone() httpTransport.MaxIdleConns = 5 - httpClient := &http.Client{ + httpClient := &http.Client{ Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ClientID == "" { diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 6e417d394..48e4f3000 100644 --- a/management/server/idp/google_workspace.go +++ b/management/server/idp/google_workspace.go @@ -51,7 +51,7 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.CustomerID == "" { diff --git a/management/server/idp/keycloak.go b/management/server/idp/keycloak.go index b640f7520..1cf26394f 100644 --- a/management/server/idp/keycloak.go +++ b/management/server/idp/keycloak.go @@ -66,7 +66,7 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ClientID == "" { diff --git a/management/server/idp/pocketid.go b/management/server/idp/pocketid.go index ee8e304ee..fc338b86b 100644 --- a/management/server/idp/pocketid.go +++ b/management/server/idp/pocketid.go @@ -90,7 +90,7 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} if config.ManagementEndpoint == "" { diff --git a/management/server/idp/util.go b/management/server/idp/util.go index 4310d1388..ed82fb9e3 100644 --- a/management/server/idp/util.go +++ b/management/server/idp/util.go @@ -76,7 +76,7 @@ const ( // Provides the env variable name for use with idpTimeout function idpTimeoutEnv = "NB_IDP_TIMEOUT" // Sets the defaultTimeout to 10s. - defaultTimeout = 10 * time.Second + defaultTimeout = 10 * time.Second ) // idpTimeout returns a timeout value for the IDP diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index ea0fd0aa7..320f0c131 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -167,7 +167,7 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri Timeout: idpTimeout(), Transport: httpTransport, } - + helper := JsonParser{} hasPAT := config.PAT != "" diff --git a/management/server/permissions/modules/module.go b/management/server/permissions/modules/module.go index 1e75330eb..b7ff490d6 100644 --- a/management/server/permissions/modules/module.go +++ b/management/server/permissions/modules/module.go @@ -3,38 +3,38 @@ package modules type Module string const ( - Networks Module = "networks" - Peers Module = "peers" - RemoteJobs Module = "remote_jobs" - Groups Module = "groups" - Settings Module = "settings" - Accounts Module = "accounts" - Dns Module = "dns" - Nameservers Module = "nameservers" - Events Module = "events" - Policies Module = "policies" - Routes Module = "routes" - Users Module = "users" - SetupKeys Module = "setup_keys" - Pats Module = "pats" + Networks Module = "networks" + Peers Module = "peers" + RemoteJobs Module = "remote_jobs" + Groups Module = "groups" + Settings Module = "settings" + Accounts Module = "accounts" + Dns Module = "dns" + Nameservers Module = "nameservers" + Events Module = "events" + Policies Module = "policies" + Routes Module = "routes" + Users Module = "users" + SetupKeys Module = "setup_keys" + Pats Module = "pats" IdentityProviders Module = "identity_providers" Services Module = "services" ) var All = map[Module]struct{}{ - Networks: {}, - Peers: {}, - RemoteJobs: {}, - Groups: {}, - Settings: {}, - Accounts: {}, - Dns: {}, - Nameservers: {}, - Events: {}, - Policies: {}, - Routes: {}, - Users: {}, - SetupKeys: {}, - Pats: {}, + Networks: {}, + Peers: {}, + RemoteJobs: {}, + Groups: {}, + Settings: {}, + Accounts: {}, + Dns: {}, + Nameservers: {}, + Events: {}, + Policies: {}, + Routes: {}, + Users: {}, + SetupKeys: {}, + Pats: {}, IdentityProviders: {}, } diff --git a/management/server/util/util.go b/management/server/util/util.go index ce9759864..617484274 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -50,4 +50,3 @@ func contains[T comparableObject[T]](slice []T, element T) bool { } return false } - diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 9fb455b7a..c43e7351c 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -2845,6 +2845,10 @@ components: domain: type: string description: Domain for the reverse proxy + proxy_cluster: + type: string + description: The proxy cluster handling this reverse proxy (derived from domain) + example: "eu.proxy.netbird.io" targets: type: array items: @@ -3007,6 +3011,21 @@ components: description: Whether link auth is enabled required: - enabled + ProxyCluster: + type: object + description: A proxy cluster represents a group of proxy nodes serving the same address + properties: + address: + type: string + description: Cluster address used for CNAME targets + example: "eu.proxy.netbird.io" + connected_proxies: + type: integer + description: Number of proxy nodes connected in this cluster + example: 3 + required: + - address + - connected_proxies ReverseProxyDomainType: type: string description: Type of Reverse Proxy Domain @@ -6702,6 +6721,29 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/reverse-proxies/clusters: + get: + summary: List available proxy clusters + description: Returns a list of available proxy clusters with their connection status + tags: [ Reverse Proxy ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of proxy clusters + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProxyCluster' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/reverse-proxies/{proxyId}: get: summary: Retrieve a Reverse Proxy diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 91253de9b..0d56bcca4 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1947,6 +1947,15 @@ type ProxyAccessLog struct { UserId *string `json:"user_id,omitempty"` } +// ProxyCluster A proxy cluster represents a group of proxy nodes serving the same address +type ProxyCluster struct { + // Address Cluster address used for CNAME targets + Address string `json:"address"` + + // ConnectedProxies Number of proxy nodes connected in this cluster + ConnectedProxies int `json:"connected_proxies"` +} + // Resource defines model for Resource. type Resource struct { // Id ID of the resource @@ -1974,6 +1983,9 @@ type ReverseProxy struct { // Name Reverse proxy name Name string `json:"name"` + // ProxyCluster The proxy cluster handling this reverse proxy (derived from domain) + ProxyCluster *string `json:"proxy_cluster,omitempty"` + // Targets List of target backends for this reverse proxy Targets []ReverseProxyTarget `json:"targets"` }