mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 03:36:41 +00:00
Merge remote-tracking branch 'origin/prototype/reverse-proxy-clusters' into prototype/reverse-proxy
# Conflicts: # management/internals/modules/reverseproxy/manager/manager.go # management/internals/modules/reverseproxy/reverseproxy.go # management/internals/server/modules.go # management/internals/shared/grpc/proxy.go # management/server/http/handler.go # management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -52,6 +53,12 @@ type reverseProxyManager interface {
|
||||
SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error
|
||||
}
|
||||
|
||||
// ClusterInfo contains information about a proxy cluster.
|
||||
type ClusterInfo struct {
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -59,6 +66,9 @@ type ProxyServiceServer struct {
|
||||
// Map of connected proxies: proxy_id -> proxy connection
|
||||
connectedProxies sync.Map
|
||||
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
|
||||
// Channel for broadcasting reverse proxy updates to all proxies
|
||||
updatesChan chan *proto.ProxyMapping
|
||||
|
||||
@@ -127,13 +137,18 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy_id is required")
|
||||
}
|
||||
|
||||
log.Infof("Proxy %s connected (version: %s, started: %s)",
|
||||
proxyID, req.GetVersion(), req.GetStartedAt().AsTime())
|
||||
proxyAddress := req.GetAddress()
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
"version": req.GetVersion(),
|
||||
"started": req.GetStartedAt().AsTime(),
|
||||
}).Info("Proxy connected")
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: req.GetAddress(),
|
||||
address: proxyAddress,
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.ProxyMapping, 100),
|
||||
ctx: connCtx,
|
||||
@@ -141,8 +156,16 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
s.addToCluster(conn.address, proxyID)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": extractClusterAddr(proxyAddress),
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
s.removeFromCluster(conn.address, proxyID)
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
@@ -163,17 +186,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of all reverse proxies to proxy
|
||||
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
|
||||
// Only reverse proxies matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength.
|
||||
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
// TODO: something?
|
||||
return fmt.Errorf("get account reverse proxies from store: %w", err)
|
||||
return fmt.Errorf("get reverse proxies from store: %w", err)
|
||||
}
|
||||
|
||||
proxyClusterAddr := extractClusterAddr(conn.address)
|
||||
|
||||
for _, rp := range reverseProxies {
|
||||
if !rp.Enabled {
|
||||
// We don't care about disabled reverse proxies for snapshots.
|
||||
continue
|
||||
}
|
||||
|
||||
if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -205,6 +233,22 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractClusterAddr extracts the host from a proxy address URL.
|
||||
func extractClusterAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
host := u.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
return h
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
@@ -281,17 +325,106 @@ func (s *ProxyServiceServer) GetConnectedProxies() []string {
|
||||
func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
||||
seenUrls := make(map[string]struct{})
|
||||
var urls []string
|
||||
var proxyCount int
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
proxyCount++
|
||||
conn := value.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": conn.proxyID,
|
||||
"address": conn.address,
|
||||
}).Debug("checking connected proxy for URL")
|
||||
if _, seen := seenUrls[conn.address]; conn.address != "" && !seen {
|
||||
seenUrls[conn.address] = struct{}{}
|
||||
urls = append(urls, conn.address)
|
||||
}
|
||||
return true
|
||||
})
|
||||
log.WithFields(log.Fields{
|
||||
"total_proxies": proxyCount,
|
||||
"unique_urls": len(urls),
|
||||
"connected_urls": urls,
|
||||
}).Debug("GetConnectedProxyURLs result")
|
||||
return urls
|
||||
}
|
||||
|
||||
// addToCluster registers a proxy in a cluster.
|
||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
|
||||
// removeFromCluster removes a proxy from a cluster.
|
||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster.
|
||||
// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility).
|
||||
func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
||||
if clusterAddr == "" {
|
||||
s.SendReverseProxyUpdate(update)
|
||||
return
|
||||
}
|
||||
|
||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr)
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxyID := key.(string)
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
select {
|
||||
case conn.sendChan <- update:
|
||||
log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
||||
clusterCounts := make(map[string]int)
|
||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
||||
clusterAddr := key.(string)
|
||||
proxySet := value.(*sync.Map)
|
||||
count := 0
|
||||
proxySet.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
clusterCounts[clusterAddr] = count
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
||||
for addr, count := range clusterCounts {
|
||||
clusters = append(clusters, ClusterInfo{
|
||||
Address: addr,
|
||||
ConnectedProxies: count,
|
||||
})
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user