mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 00:06:38 +00:00
Add graceful shutdown for Kubernetes
This commit is contained in:
@@ -4,14 +4,17 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/acme"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy"
|
||||
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -47,6 +50,7 @@ var (
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
certLockMethod string
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -78,6 +82,7 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
|
||||
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -145,9 +150,13 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
OIDCScopes: strings.Split(oidcScopes, ","),
|
||||
ForwardedProto: forwardedProto,
|
||||
TrustedProxies: parsedTrustedProxies,
|
||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||
}
|
||||
|
||||
if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
|
||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||
defer stop()
|
||||
|
||||
if err := srv.ListenAndServe(ctx, addr); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -68,6 +68,9 @@ type flockLocker struct {
|
||||
}
|
||||
|
||||
func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &flockLocker{certDir: certDir, logger: logger}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,6 +32,9 @@ func Lock(ctx context.Context, path string) (*os.File, error) {
|
||||
return nil, fmt.Errorf("open lock file %s: %w", path, err)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(retryInterval)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
|
||||
return f, nil
|
||||
@@ -48,7 +51,8 @@ func Lock(ctx context.Context, path string) (*os.File, error) {
|
||||
log.Debugf("close lock file %s: %v", path, cerr)
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(retryInterval):
|
||||
case <-timer.C:
|
||||
timer.Reset(retryInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxConcurrentChecks = 3
|
||||
maxClientCheckTimeout = 5 * time.Minute
|
||||
maxConcurrentChecks = 3
|
||||
maxClientCheckTimeout = 5 * time.Minute
|
||||
)
|
||||
|
||||
// clientProvider provides access to NetBird clients for health checks.
|
||||
@@ -34,6 +34,7 @@ type Checker struct {
|
||||
mu sync.RWMutex
|
||||
managementConnected bool
|
||||
initialSyncComplete bool
|
||||
shuttingDown bool
|
||||
|
||||
// checkSem limits concurrent client health checks.
|
||||
checkSem chan struct{}
|
||||
@@ -77,6 +78,14 @@ func (c *Checker) SetInitialSyncComplete() {
|
||||
c.initialSyncComplete = true
|
||||
}
|
||||
|
||||
// SetShuttingDown marks the server as shutting down.
|
||||
// This causes ReadinessProbe to return false so load balancers stop routing traffic.
|
||||
func (c *Checker) SetShuttingDown() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.shuttingDown = true
|
||||
}
|
||||
|
||||
// CheckClientsConnected verifies all clients are connected to management/signal/relay.
|
||||
// Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout.
|
||||
// Limits concurrent checks via semaphore.
|
||||
@@ -145,6 +154,9 @@ func (c *Checker) LivenessProbe() bool {
|
||||
func (c *Checker) ReadinessProbe() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.shuttingDown {
|
||||
return false
|
||||
}
|
||||
return c.managementConnected
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
160
proxy/server.go
160
proxy/server.go
@@ -19,6 +19,7 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
@@ -36,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
@@ -71,10 +73,13 @@ type Server struct {
|
||||
GenerateACMECertificates bool
|
||||
ACMEChallengeAddress string
|
||||
ACMEDirectory string
|
||||
OIDCClientId string
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
// CertLockMethod controls how ACME certificate locks are coordinated
|
||||
// across replicas. Default: CertLockAuto (detect environment).
|
||||
CertLockMethod acme.CertLockMethod
|
||||
OIDCClientId string
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
|
||||
// DebugEndpointEnabled enables the debug HTTP endpoint.
|
||||
DebugEndpointEnabled bool
|
||||
@@ -173,6 +178,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
@@ -184,17 +194,14 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger)
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil {
|
||||
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
|
||||
for range time.Tick(10 * time.Second) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
|
||||
}
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
@@ -250,14 +257,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := s.debug.Close(); err != nil {
|
||||
s.Logger.Debugf("debug endpoint close: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start health probe server on separate port for Kubernetes probes.
|
||||
// Start health probe server.
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = "localhost:8080"
|
||||
@@ -272,30 +274,120 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.healthServer.Shutdown(shutdownCtx); err != nil {
|
||||
s.Logger.Debugf("health probe server shutdown: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.netbird.StopAll(stopCtx); err != nil {
|
||||
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Finally, start the reverse proxy.
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy))),
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
s.Logger.Debugf("starting listening on reverse proxy server address %s", addr)
|
||||
return s.https.ListenAndServeTLS("", "")
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||
// before draining connections. This allows the load balancer to propagate
|
||||
// the endpoint removal.
|
||||
shutdownPreStopDelay = 5 * time.Second
|
||||
|
||||
// shutdownDrainTimeout is the maximum time to wait for in-flight HTTP
|
||||
// requests to complete during graceful shutdown.
|
||||
shutdownDrainTimeout = 30 * time.Second
|
||||
|
||||
// shutdownServiceTimeout is the maximum time to wait for auxiliary
|
||||
// services (health probe, debug endpoint, ACME) to shut down.
|
||||
shutdownServiceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
|
||||
// readiness probe as failing, waits for load balancer propagation, drains
|
||||
// in-flight connections, and then stops all background services.
|
||||
func (s *Server) gracefulShutdown() {
|
||||
s.Logger.Info("shutdown signal received, starting graceful shutdown")
|
||||
|
||||
// Step 1: Fail readiness probe so load balancers stop routing new traffic.
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetShuttingDown()
|
||||
}
|
||||
|
||||
// Step 2: When running behind a load balancer, wait for endpoint removal
|
||||
// to propagate before draining connections.
|
||||
if k8s.InCluster() {
|
||||
s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay)
|
||||
time.Sleep(shutdownPreStopDelay)
|
||||
}
|
||||
|
||||
// Step 3: Stop accepting new connections and drain in-flight requests.
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
|
||||
defer drainCancel()
|
||||
|
||||
s.Logger.Info("draining in-flight connections")
|
||||
if err := s.https.Shutdown(drainCtx); err != nil {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Stop all remaining background services.
|
||||
s.shutdownServices()
|
||||
s.Logger.Info("graceful shutdown complete")
|
||||
}
|
||||
|
||||
// shutdownServices stops all background services concurrently and waits for
|
||||
// them to finish.
|
||||
func (s *Server) shutdownServices() {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
shutdownHTTP := func(name string, shutdown func(context.Context) error) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
|
||||
defer cancel()
|
||||
if err := shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Debugf("%s shutdown: %v", name, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.healthServer != nil {
|
||||
shutdownHTTP("health probe", s.healthServer.Shutdown)
|
||||
}
|
||||
if s.debug != nil {
|
||||
shutdownHTTP("debug endpoint", s.debug.Shutdown)
|
||||
}
|
||||
if s.http != nil {
|
||||
shutdownHTTP("acme http", s.http.Shutdown)
|
||||
}
|
||||
|
||||
if s.netbird != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
|
||||
defer cancel()
|
||||
if err := s.netbird.StopAll(ctx); err != nil {
|
||||
s.Logger.Warnf("stop netbird clients: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
|
||||
|
||||
Reference in New Issue
Block a user