[proxy] consolidate mapping update (#6072)

This commit is contained in:
Pascal Fischer
2026-05-05 18:40:42 +02:00
committed by GitHub
parent 3c28d29725
commit cfb1b3fe31
6 changed files with 559 additions and 58 deletions

View File

@@ -11,6 +11,8 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -82,11 +84,40 @@ type ProxyServiceServer struct {
// Store for PKCE verifiers
pkceVerifierStore *PKCEVerifierStore
// tokenTTL is the lifetime of one-time tokens generated for proxy
// authentication. Defaults to defaultProxyTokenTTL when zero.
tokenTTL time.Duration
// snapshotBatchSize is the number of mappings per gRPC message during
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
snapshotBatchSize int
cancel context.CancelFunc
}
const pkceVerifierTTL = 10 * time.Minute
const defaultProxyTokenTTL = 5 * time.Minute
const defaultSnapshotBatchSize = 500
func snapshotBatchSizeFromEnv() int {
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
return n
}
}
return defaultSnapshotBatchSize
}
// proxyTokenTTL returns the configured token TTL or the default when unset.
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
if s.tokenTTL > 0 {
return s.tokenTTL
}
return defaultProxyTokenTTL
}
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
@@ -110,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
snapshotBatchSize: snapshotBatchSizeFromEnv(),
cancel: cancel,
}
go s.cleanupStaleProxies(ctx)
@@ -192,11 +224,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
cancel: cancel,
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database with capabilities
var caps *proxy.Capabilities
if c := req.GetCapabilities(); c != nil {
@@ -209,13 +236,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.CompareAndDelete(proxyID, conn)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
cancel()
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
}
s.connectedProxies.Store(proxyID, conn)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
if err := s.sendSnapshot(ctx, conn); err != nil {
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
}
}
cancel()
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
}
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
@@ -241,13 +286,6 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
}
errChan := make(chan error, 2)
go s.sender(conn, errChan)
go s.heartbeat(connCtx, proxyRecord)
select {
@@ -290,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return err
}
// Send mappings in batches to reduce per-message gRPC overhead while
// staying well within the default 4 MB message size limit.
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
end := i + s.snapshotBatchSize
if end > len(mappings) {
end = len(mappings)
}
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: mappings[i:end],
InitialSyncComplete: end == len(mappings),
}); err != nil {
return fmt.Errorf("send snapshot batch: %w", err)
}
}
if len(mappings) == 0 {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
InitialSyncComplete: true,
}); err != nil {
return fmt.Errorf("send snapshot completion: %w", err)
}
return nil
}
for i, m := range mappings {
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{m},
InitialSyncComplete: i == len(mappings)-1,
}); err != nil {
return fmt.Errorf("send proxy mapping: %w", err)
}
}
return nil
@@ -323,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
continue
}
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
if err != nil {
log.WithFields(log.Fields{
"service": service.Name,
"account": service.AccountID,
}).WithError(err).Error("failed to generate auth token for snapshot")
continue
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
}
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
@@ -409,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
conn := value.(*proxyConnection)
resp := s.perProxyMessage(update, conn.proxyID)
if resp == nil {
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
return true
}
select {
case conn.sendChan <- resp:
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
conn.cancel()
}
return true
})
@@ -495,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
continue
}
select {
case conn.sendChan <- msg:
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
conn.cancel()
}
}
}
@@ -527,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
// Returns nil if token generation fails; the caller must disconnect the
// proxy so it can resync via a fresh snapshot on reconnect.
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
@@ -536,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil