Add graceful shutdown for Kubernetes

This commit is contained in:
Viktor Liu
2026-02-09 20:17:12 +08:00
parent fd442138e6
commit 53c1016a8e
6 changed files with 158 additions and 39 deletions

View File

@@ -4,14 +4,17 @@ import (
"context"
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/acme"
"github.com/netbirdio/netbird/proxy"
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/util"
)
@@ -47,6 +50,7 @@ var (
trustedProxies string
certFile string
certKeyFile string
certLockMethod string
)
var rootCmd = &cobra.Command{
@@ -78,6 +82,7 @@ func init() {
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
}
// Execute runs the root command.
@@ -145,9 +150,13 @@ func runServer(cmd *cobra.Command, args []string) error {
OIDCScopes: strings.Split(oidcScopes, ","),
ForwardedProto: forwardedProto,
TrustedProxies: parsedTrustedProxies,
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
}
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
defer stop()
if err := srv.ListenAndServe(ctx, addr); err != nil {
log.Fatal(err)
}
return nil

View File

@@ -68,6 +68,9 @@ type flockLocker struct {
}
// newFlockLocker builds a flockLocker rooted at certDir. Passing a nil
// logger is allowed; it falls back to logrus's standard logger.
func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
	l := logger
	if l == nil {
		l = log.StandardLogger()
	}
	return &flockLocker{certDir: certDir, logger: l}
}

View File

@@ -32,6 +32,9 @@ func Lock(ctx context.Context, path string) (*os.File, error) {
return nil, fmt.Errorf("open lock file %s: %w", path, err)
}
timer := time.NewTimer(retryInterval)
defer timer.Stop()
for {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
return f, nil
@@ -48,7 +51,8 @@ func Lock(ctx context.Context, path string) (*os.File, error) {
log.Debugf("close lock file %s: %v", path, cerr)
}
return nil, ctx.Err()
case <-time.After(retryInterval):
case <-timer.C:
timer.Reset(retryInterval)
}
}
}

View File

@@ -17,8 +17,8 @@ import (
)
const (
maxConcurrentChecks = 3
maxClientCheckTimeout = 5 * time.Minute
maxConcurrentChecks = 3
maxClientCheckTimeout = 5 * time.Minute
)
// clientProvider provides access to NetBird clients for health checks.
@@ -34,6 +34,7 @@ type Checker struct {
mu sync.RWMutex
managementConnected bool
initialSyncComplete bool
shuttingDown bool
// checkSem limits concurrent client health checks.
checkSem chan struct{}
@@ -77,6 +78,14 @@ func (c *Checker) SetInitialSyncComplete() {
c.initialSyncComplete = true
}
// SetShuttingDown flags the checker as draining. Once set, ReadinessProbe
// reports false, which tells load balancers to stop sending new traffic here.
func (c *Checker) SetShuttingDown() {
	c.mu.Lock()
	c.shuttingDown = true
	c.mu.Unlock()
}
// CheckClientsConnected verifies all clients are connected to management/signal/relay.
// Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout.
// Limits concurrent checks via semaphore.
@@ -145,6 +154,9 @@ func (c *Checker) LivenessProbe() bool {
// ReadinessProbe reports whether this instance should receive traffic:
// false once shutdown has begun, otherwise the state of the management
// connection.
func (c *Checker) ReadinessProbe() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return !c.shuttingDown && c.managementConnected
}

View File

@@ -10,7 +10,6 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/hex"
"bytes"
"encoding/json"
"errors"
"fmt"

View File

@@ -19,6 +19,7 @@ import (
"net/netip"
"net/url"
"path/filepath"
"sync"
"time"
backoff "github.com/cenkalti/backoff/v4"
@@ -36,6 +37,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/k8s"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -71,10 +73,13 @@ type Server struct {
GenerateACMECertificates bool
ACMEChallengeAddress string
ACMEDirectory string
OIDCClientId string
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
// CertLockMethod controls how ACME certificate locks are coordinated
// across replicas. Default: CertLockAuto (detect environment).
CertLockMethod acme.CertLockMethod
OIDCClientId string
OIDCClientSecret string
OIDCEndpoint string
OIDCScopes []string
// DebugEndpointEnabled enables the debug HTTP endpoint.
DebugEndpointEnabled bool
@@ -173,6 +178,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
if err != nil {
return fmt.Errorf("could not create management connection: %w", err)
}
defer func() {
if err := mgmtConn.Close(); err != nil {
s.Logger.Debugf("management connection close: %v", err)
}
}()
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
go s.newManagementMappingWorker(ctx, s.mgmtClient)
@@ -184,17 +194,14 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
tlsConfig := &tls.Config{}
if s.GenerateACMECertificates {
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger)
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
s.http = &http.Server{
Addr: s.ACMEChallengeAddress,
Handler: s.acme.HTTPHandler(nil),
}
go func() {
if err := s.http.ListenAndServe(); err != nil {
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
for range time.Tick(10 * time.Second) {
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
}
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
}
}()
tlsConfig = s.acme.TLSConfig()
@@ -250,14 +257,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
}()
defer func() {
if err := s.debug.Close(); err != nil {
s.Logger.Debugf("debug endpoint close: %v", err)
}
}()
}
// Start health probe server on separate port for Kubernetes probes.
// Start health probe server.
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = "localhost:8080"
@@ -272,30 +274,120 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
s.Logger.Errorf("health probe server: %v", err)
}
}()
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.healthServer.Shutdown(shutdownCtx); err != nil {
s.Logger.Debugf("health probe server shutdown: %v", err)
}
}()
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.netbird.StopAll(stopCtx); err != nil {
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
}
}()
// Finally, start the reverse proxy.
// Start the reverse proxy HTTPS server.
s.https = &http.Server{
Addr: addr,
Handler: accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy))),
TLSConfig: tlsConfig,
}
s.Logger.Debugf("starting listening on reverse proxy server address %s", addr)
return s.https.ListenAndServeTLS("", "")
httpsErr := make(chan error, 1)
go func() {
s.Logger.Debugf("starting reverse proxy server on %s", addr)
httpsErr <- s.https.ListenAndServeTLS("", "")
}()
select {
case err := <-httpsErr:
s.shutdownServices()
if !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("https server: %w", err)
}
return nil
case <-ctx.Done():
s.gracefulShutdown()
return nil
}
}
const (
	// shutdownPreStopDelay is how long to pause after a shutdown signal
	// before draining, so the load balancer has time to propagate the
	// endpoint removal.
	shutdownPreStopDelay = 5 * time.Second

	// shutdownDrainTimeout bounds how long in-flight HTTP requests may
	// take to finish during graceful shutdown.
	shutdownDrainTimeout = 30 * time.Second

	// shutdownServiceTimeout bounds the teardown of auxiliary services
	// (health probe, debug endpoint, ACME).
	shutdownServiceTimeout = 5 * time.Second
)
// gracefulShutdown drains the server without dropping traffic: it flips the
// readiness probe to failing, gives load balancers time to react, waits for
// in-flight requests to finish, and finally stops the background services.
func (s *Server) gracefulShutdown() {
	s.Logger.Info("shutdown signal received, starting graceful shutdown")

	// Fail the readiness probe first so no new traffic is routed here.
	if hc := s.healthChecker; hc != nil {
		hc.SetShuttingDown()
	}

	// Inside Kubernetes, give the endpoint removal time to propagate
	// before we stop accepting connections.
	if k8s.InCluster() {
		s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay)
		time.Sleep(shutdownPreStopDelay)
	}

	// Stop accepting new connections and let in-flight requests finish,
	// bounded by shutdownDrainTimeout.
	ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
	defer cancel()
	s.Logger.Info("draining in-flight connections")
	if err := s.https.Shutdown(ctx); err != nil {
		s.Logger.Warnf("https server drain: %v", err)
	}

	// Everything else is torn down concurrently.
	s.shutdownServices()
	s.Logger.Info("graceful shutdown complete")
}
// shutdownServices tears down the remaining background services in parallel
// and blocks until every one of them has finished.
func (s *Server) shutdownServices() {
	var wg sync.WaitGroup

	// stopHTTP shuts down one HTTP server with a bounded timeout,
	// treating the expected ErrServerClosed as success.
	stopHTTP := func(name string, shutdown func(context.Context) error) {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
			defer cancel()
			if err := shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
				s.Logger.Debugf("%s shutdown: %v", name, err)
			}
		}()
	}

	if s.healthServer != nil {
		stopHTTP("health probe", s.healthServer.Shutdown)
	}
	if s.debug != nil {
		stopHTTP("debug endpoint", s.debug.Shutdown)
	}
	if s.http != nil {
		stopHTTP("acme http", s.http.Shutdown)
	}

	// NetBird clients use the longer drain timeout rather than the
	// service timeout.
	if nb := s.netbird; nb != nil {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
			defer cancel()
			if err := nb.StopAll(ctx); err != nil {
				s.Logger.Warnf("stop netbird clients: %v", err)
			}
		}()
	}

	wg.Wait()
}
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {