Add graceful shutdown for Kubernetes

This commit is contained in:
Viktor Liu
2026-02-09 20:17:12 +08:00
parent fd442138e6
commit 53c1016a8e
6 changed files with 158 additions and 39 deletions

View File

@@ -19,6 +19,7 @@ import (
"net/netip"
"net/url"
"path/filepath"
"sync"
"time"
backoff "github.com/cenkalti/backoff/v4"
@@ -36,6 +37,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -71,10 +73,13 @@ type Server struct {
GenerateACMECertificates bool
ACMEChallengeAddress string
ACMEDirectory string
OIDCClientId string
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
// CertLockMethod controls how ACME certificate locks are coordinated
// across replicas. Default: CertLockAuto (detect environment).
CertLockMethod acme.CertLockMethod
OIDCClientId string
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
// DebugEndpointEnabled enables the debug HTTP endpoint.
DebugEndpointEnabled bool
@@ -173,6 +178,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
if err != nil {
return fmt.Errorf("could not create management connection: %w", err)
}
defer func() {
if err := mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}()
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
go s.newManagementMappingWorker(ctx, s.mgmtClient)
@@ -184,17 +194,14 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
tlsConfig := &tls.Config{}
if s.GenerateACMECertificates {
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger)
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
s.http = &http.Server{
Addr: s.ACMEChallengeAddress,
Handler: s.acme.HTTPHandler(nil),
}
go func() {
if err := s.http.ListenAndServe(); err != nil {
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
for range time.Tick(10 * time.Second) {
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
}
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
}
}()
tlsConfig = s.acme.TLSConfig()
@@ -250,14 +257,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
defer func() {
if err := s.debug.Close(); err != nil {
s.Logger.Debugf("debug endpoint close: %v", err)
}
}()
}
// Start health probe server on separate port for Kubernetes probes.
// Start health probe server.
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = "localhost:8080"
@@ -272,30 +274,120 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.Logger.Errorf("health probe server: %v", err)
}
}()
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.healthServer.Shutdown(shutdownCtx); err != nil {
s.Logger.Debugf("health probe server shutdown: %v", err)
}
}()
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.netbird.StopAll(stopCtx); err != nil {
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
}
}()
// Finally, start the reverse proxy.
// Start the reverse proxy HTTPS server.
s.https = &http.Server{
Addr: addr,
Handler: accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy))),
TLSConfig: tlsConfig,
}
s.Logger.Debugf("starting listening on reverse proxy server address %s", addr)
return s.https.ListenAndServeTLS("", "")
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debugf("starting reverse proxy server on %s", addr)
httpsErr <- s.https.ListenAndServeTLS("", "")
}()
select {
case err := <-httpsErr:
s.shutdownServices()
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("https server: %w", err)
}
return nil
case <-ctx.Done():
s.gracefulShutdown()
return nil
}
}
// Timing limits applied by the shutdown sequence.
const (
	// shutdownPreStopDelay is how long gracefulShutdown pauses, when running
	// inside a Kubernetes cluster, before it begins draining the HTTPS
	// server. The pause gives the load balancer time to propagate the
	// removal of this endpoint.
	shutdownPreStopDelay = 5 * time.Second

	// shutdownDrainTimeout bounds how long in-flight HTTP requests may take
	// to finish while the HTTPS server is drained. The same limit is applied
	// when stopping the netbird clients.
	shutdownDrainTimeout = 30 * time.Second

	// shutdownServiceTimeout bounds how long each auxiliary HTTP service
	// (health probe, debug endpoint, ACME challenge server) may take to
	// shut down.
	shutdownServiceTimeout = 5 * time.Second
)
// gracefulShutdown runs the ordered shutdown sequence used when the serving
// context is cancelled:
//
//  1. flag the health checker as shutting down, so the readiness probe
//     fails and load balancers stop sending new traffic;
//  2. when running inside a Kubernetes cluster, pause for
//     shutdownPreStopDelay so the endpoint removal can propagate;
//  3. shut down the HTTPS server, giving in-flight requests up to
//     shutdownDrainTimeout to complete;
//  4. stop the remaining background services via shutdownServices.
//
// Failures while draining are logged and do not abort the sequence.
func (s *Server) gracefulShutdown() {
	s.Logger.Info("shutdown signal received, starting graceful shutdown")

	// The health checker is optional; only flag it when one is configured.
	if checker := s.healthChecker; checker != nil {
		checker.SetShuttingDown()
	}

	// Outside a cluster there is no endpoint propagation to wait for, so the
	// delay is skipped entirely.
	if k8s.InCluster() {
		s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay)
		time.Sleep(shutdownPreStopDelay)
	}

	// Use a fresh background context: the context that triggered shutdown is
	// already cancelled and would end the drain immediately.
	ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
	defer cancel()

	s.Logger.Info("draining in-flight connections")
	err := s.https.Shutdown(ctx)
	if err != nil {
		s.Logger.Warnf("https server drain: %v", err)
	}

	s.shutdownServices()
	s.Logger.Info("graceful shutdown complete")
}
// shutdownServices stops every configured background service in parallel
// and returns once all of them have finished. Services that were never
// started (nil fields) are skipped. Errors are logged, never returned.
func (s *Server) shutdownServices() {
	var pending sync.WaitGroup

	// spawn runs task on its own goroutine and tracks it in pending.
	spawn := func(task func()) {
		pending.Add(1)
		go func() {
			defer pending.Done()
			task()
		}()
	}

	// stopHTTP shuts down one HTTP-style service, allowing it at most
	// shutdownServiceTimeout. An http.ErrServerClosed result is not logged.
	stopHTTP := func(name string, shutdown func(context.Context) error) {
		spawn(func() {
			ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
			defer cancel()

			err := shutdown(ctx)
			if err == nil || errors.Is(err, http.ErrServerClosed) {
				return
			}
			s.Logger.Debugf("%s shutdown: %v", name, err)
		})
	}

	if s.healthServer != nil {
		stopHTTP("health probe", s.healthServer.Shutdown)
	}
	if s.debug != nil {
		stopHTTP("debug endpoint", s.debug.Shutdown)
	}
	if s.http != nil {
		stopHTTP("acme http", s.http.Shutdown)
	}

	// The netbird clients get the longer drain timeout rather than the
	// auxiliary-service timeout.
	if s.netbird != nil {
		spawn(func() {
			ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
			defer cancel()

			if err := s.netbird.StopAll(ctx); err != nil {
				s.Logger.Warnf("stop netbird clients: %v", err)
			}
		})
	}

	pending.Wait()
}
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {