mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-29 05:36:39 +00:00
add clusters logic
This commit is contained in:
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -37,6 +39,12 @@ type keyStore interface {
|
||||
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error)
|
||||
}
|
||||
|
||||
// ClusterInfo contains information about a proxy cluster.
|
||||
type ClusterInfo struct {
|
||||
Address string
|
||||
ConnectedProxies int
|
||||
}
|
||||
|
||||
// ProxyServiceServer implements the ProxyService gRPC server
|
||||
type ProxyServiceServer struct {
|
||||
proto.UnimplementedProxyServiceServer
|
||||
@@ -44,6 +52,9 @@ type ProxyServiceServer struct {
|
||||
// Map of connected proxies: proxy_id -> proxy connection
|
||||
connectedProxies sync.Map
|
||||
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
|
||||
// Channel for broadcasting reverse proxy updates to all proxies
|
||||
updatesChan chan *proto.ProxyMapping
|
||||
|
||||
@@ -115,8 +126,10 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
s.addToCluster(conn.address, proxyID)
|
||||
defer func() {
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
s.removeFromCluster(conn.address, proxyID)
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
@@ -137,17 +150,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
}
|
||||
}
|
||||
|
||||
// sendSnapshot sends the initial snapshot of all reverse proxies to proxy
|
||||
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
|
||||
// Only reverse proxies matching the proxy's cluster address are sent.
|
||||
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
|
||||
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) // TODO: check locking strength.
|
||||
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
// TODO: something?
|
||||
return fmt.Errorf("get account reverse proxies from store: %w", err)
|
||||
return fmt.Errorf("get reverse proxies from store: %w", err)
|
||||
}
|
||||
|
||||
proxyClusterAddr := extractClusterAddr(conn.address)
|
||||
|
||||
for _, rp := range reverseProxies {
|
||||
if !rp.Enabled {
|
||||
// We don't care about disabled reverse proxies for snapshots.
|
||||
continue
|
||||
}
|
||||
|
||||
if rp.ProxyCluster != "" && proxyClusterAddr != "" && rp.ProxyCluster != proxyClusterAddr {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -160,7 +178,6 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: should this even be here? We're running in a loop, and on each proxy, this will create a LOT of setup key entries that we currently have no way to remove.
|
||||
key, err := s.keyStore.CreateSetupKey(ctx,
|
||||
rp.AccountID,
|
||||
rp.Name,
|
||||
@@ -184,7 +201,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{
|
||||
rp.ToProtoMapping(
|
||||
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
|
||||
reverseproxy.Create,
|
||||
key.Key,
|
||||
),
|
||||
},
|
||||
@@ -197,6 +214,22 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractClusterAddr extracts the host from a proxy address URL.
|
||||
func extractClusterAddr(addr string) string {
|
||||
if addr == "" {
|
||||
return ""
|
||||
}
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
host := u.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
return h
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// sender handles sending messages to proxy
|
||||
func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error) {
|
||||
for {
|
||||
@@ -284,6 +317,84 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
||||
return urls
|
||||
}
|
||||
|
||||
// addToCluster registers a proxy in a cluster.
|
||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
|
||||
// removeFromCluster removes a proxy from a cluster.
|
||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
||||
if clusterAddr == "" {
|
||||
return
|
||||
}
|
||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// SendReverseProxyUpdateToCluster sends a reverse proxy update to all proxies in a specific cluster.
|
||||
// If clusterAddr is empty, broadcasts to all connected proxies (backward compatibility).
|
||||
func (s *ProxyServiceServer) SendReverseProxyUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
||||
if clusterAddr == "" {
|
||||
s.SendReverseProxyUpdate(update)
|
||||
return
|
||||
}
|
||||
|
||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Sending reverse proxy update to cluster %s", clusterAddr)
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxyID := key.(string)
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
select {
|
||||
case conn.sendChan <- update:
|
||||
log.Debugf("Sent reverse proxy update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.Warnf("Failed to send reverse proxy update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetAvailableClusters returns information about all connected proxy clusters.
|
||||
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
|
||||
clusterCounts := make(map[string]int)
|
||||
s.clusterProxies.Range(func(key, value interface{}) bool {
|
||||
clusterAddr := key.(string)
|
||||
proxySet := value.(*sync.Map)
|
||||
count := 0
|
||||
proxySet.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
if count > 0 {
|
||||
clusterCounts[clusterAddr] = count
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
clusters := make([]ClusterInfo, 0, len(clusterCounts))
|
||||
for addr, count := range clusterCounts {
|
||||
clusters = append(clusters, ClusterInfo{
|
||||
Address: addr,
|
||||
ConnectedProxies: count,
|
||||
})
|
||||
}
|
||||
return clusters
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user