[management] fix proxy reconnect (#6063)

This commit is contained in:
Pascal Fischer
2026-05-04 20:43:25 +02:00
committed by GitHub
parent 77a0992dc2
commit 97db824929
9 changed files with 199 additions and 77 deletions

View File

@@ -16,6 +16,7 @@ import (
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"google.golang.org/grpc/codes"
@@ -89,6 +90,7 @@ const pkceVerifierTTL = 10 * time.Minute
// proxyConnection represents a connected proxy
type proxyConnection struct {
proxyID string
sessionID string
address string
capabilities *proto.ProxyCapabilities
stream proto.ProxyService_GetMappingUpdateServer
@@ -166,9 +168,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
}
sessionID := uuid.NewString()
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
oldConn := old.(*proxyConnection)
log.WithFields(log.Fields{
"proxy_id": proxyID,
"old_session_id": oldConn.sessionID,
"new_session_id": sessionID,
}).Info("Superseding existing proxy connection")
oldConn.cancel()
}
connCtx, cancel := context.WithCancel(ctx)
conn := &proxyConnection{
proxyID: proxyID,
sessionID: sessionID,
address: proxyAddress,
capabilities: req.GetCapabilities(),
stream: stream,
@@ -188,12 +203,13 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
caps = &proxy.Capabilities{
SupportsCustomPorts: c.SupportsCustomPorts,
RequireSubdomain: c.RequireSubdomain,
SupportsCrowdsec: c.SupportsCrowdsec,
SupportsCrowdsec: c.SupportsCrowdsec,
}
}
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
if err != nil {
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
s.connectedProxies.Delete(proxyID)
s.connectedProxies.CompareAndDelete(proxyID, conn)
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
}
@@ -202,22 +218,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
log.WithFields(log.Fields{
"proxy_id": proxyID,
"session_id": sessionID,
"address": proxyAddress,
"cluster_addr": proxyAddress,
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
cancel()
return
}
s.connectedProxies.Delete(proxyID)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s disconnected", proxyID)
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
}()
if err := s.sendSnapshot(ctx, conn); err != nil {
@@ -227,29 +248,31 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2)
go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
go s.heartbeat(connCtx, proxyRecord)
select {
case err := <-errChan:
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
case <-connCtx.Done():
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
return connCtx.Err()
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
}
case <-ctx.Done():
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
return
}
}