Merge remote-tracking branch 'origin/prototype/reverse-proxy-clusters' into prototype/reverse-proxy

# Conflicts:
#	management/internals/modules/reverseproxy/manager/manager.go
#	management/internals/modules/reverseproxy/reverseproxy.go
#	management/internals/server/modules.go
#	management/internals/shared/grpc/proxy.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-02-05 15:19:57 +01:00
22 changed files with 466 additions and 126 deletions

View File

@@ -9,6 +9,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"net"
"net/url"
"strings"
"sync"
@@ -52,6 +53,12 @@ type reverseProxyManager interface {
SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error
}
// ClusterInfo contains information about a proxy cluster.
type ClusterInfo struct {
Address string
ConnectedProxies int
}
// ProxyServiceServer implements the ProxyService gRPC server
type ProxyServiceServer struct {
proto.UnimplementedProxyServiceServer
@@ -59,6 +66,9 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
@@ -127,13 +137,18 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
}
log.Infof("Proxy %s connected (version: %s, started: %s)",
proxyID, req.GetVersion(), req.GetStartedAt().AsTime())
proxyAddress := req.GetAddress()
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
"version": req.GetVersion(),
"started": req.GetStartedAt().AsTime(),
}).Info("Proxy connected")
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
address: req.GetAddress(),
address: proxyAddress,
stream: stream,
sendChan: make(chan *proto.ProxyMapping, 100),
ctx: connCtx,
@@ -141,8 +156,16 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
"cluster_addr": extractClusterAddr(proxyAddress),
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID)
cancel()
log.Infof("Proxy %s disconnected", proxyID)
}()
@@ -163,17 +186,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// sendSnapshot sends the initial snapshot of all reverse proxies to proxy
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
// Only reverse proxies matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength.
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
if err != nil {
// TODO: something?
return fmt.Errorf("get account reverse proxies from store: %w", err)
return fmt.Errorf("get reverse proxies from store: %w", err)
}
proxyClusterAddr := extractClusterAddr(conn.address)
for _, rp := range reverseProxies {
if !rp.Enabled {
// We don't care about disabled reverse proxies for snapshots.
continue
}
if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr {
continue
}
@@ -205,6 +233,22 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return nil
}
// extractClusterAddr extracts the host from a proxy address URL.
func extractClusterAddr(addr string) string {
if addr == "" {
return ""
}
u, err := url.Parse(addr)
if err != nil {
return addr
}
host := u.Host
if h, _, err := net.SplitHostPort(host); err == nil {
return h
}
return host
}
// sender handles sending messages to proxy
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
for {
@@ -281,17 +325,106 @@ func (s *ProxyServiceServer) GetConnectedProxies() []string {
func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
seenUrls := make(map[string]struct{})
var urls []string
var proxyCount int
s.connectedProxies.Range(func(key, value interface{}) bool {
proxyCount++
conn := value.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": conn.proxyID,
"address": conn.address,
}).Debug("checking connected proxy for URL")
if _, seen := seenUrls[conn.address]; conn.address != "" && !seen {
seenUrls[conn.address] = struct{}{}
urls = append(urls, conn.address)
}
return true
})
log.WithFields(log.Fields{
"total_proxies": proxyCount,
"unique_urls": len(urls),
"connected_urls": urls,
}).Debug("GetConnectedProxyURLs result")
return urls
}
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility).
func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
if clusterAddr == "" {
s.SendReverseProxyUpdate(update)
return
}
proxySet, ok := s.clusterProxies.Load(clusterAddr)
if !ok {
log.Debugf("No proxies connected for cluster %s", clusterAddr)
return
}
log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
select {
case conn.sendChan <- update:
log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
})
}
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId())
if err != nil {