mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add graceful shutdown for Kubernetes
This commit is contained in:
160
proxy/server.go
160
proxy/server.go
@@ -19,6 +19,7 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
backoff "github.com/cenkalti/backoff/v4"
|
||||
@@ -36,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
@@ -71,10 +73,13 @@ type Server struct {
|
||||
GenerateACMECertificates bool
|
||||
ACMEChallengeAddress string
|
||||
ACMEDirectory string
|
||||
OIDCClientId string
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
// CertLockMethod controls how ACME certificate locks are coordinated
|
||||
// across replicas. Default: CertLockAuto (detect environment).
|
||||
CertLockMethod acme.CertLockMethod
|
||||
OIDCClientId string
|
||||
OIDCClientSecret string
|
||||
OIDCEndpoint string
|
||||
OIDCScopes []string
|
||||
|
||||
// DebugEndpointEnabled enables the debug HTTP endpoint.
|
||||
DebugEndpointEnabled bool
|
||||
@@ -173,6 +178,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create management connection: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := mgmtConn.Close(); err != nil {
|
||||
s.Logger.Debugf("management connection close: %v", err)
|
||||
}
|
||||
}()
|
||||
s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
|
||||
go s.newManagementMappingWorker(ctx, s.mgmtClient)
|
||||
|
||||
@@ -184,17 +194,14 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
tlsConfig := &tls.Config{}
|
||||
if s.GenerateACMECertificates {
|
||||
s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger)
|
||||
s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
|
||||
s.http = &http.Server{
|
||||
Addr: s.ACMEChallengeAddress,
|
||||
Handler: s.acme.HTTPHandler(nil),
|
||||
}
|
||||
go func() {
|
||||
if err := s.http.ListenAndServe(); err != nil {
|
||||
// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
|
||||
for range time.Tick(10 * time.Second) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
|
||||
}
|
||||
if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
|
||||
}
|
||||
}()
|
||||
tlsConfig = s.acme.TLSConfig()
|
||||
@@ -250,14 +257,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.Logger.Errorf("debug endpoint error: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := s.debug.Close(); err != nil {
|
||||
s.Logger.Debugf("debug endpoint close: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start health probe server on separate port for Kubernetes probes.
|
||||
// Start health probe server.
|
||||
healthAddr := s.HealthAddress
|
||||
if healthAddr == "" {
|
||||
healthAddr = "localhost:8080"
|
||||
@@ -272,30 +274,120 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
s.Logger.Errorf("health probe server: %v", err)
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.healthServer.Shutdown(shutdownCtx); err != nil {
|
||||
s.Logger.Debugf("health probe server shutdown: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.netbird.StopAll(stopCtx); err != nil {
|
||||
s.Logger.Warnf("failed to stop all netbird clients: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Finally, start the reverse proxy.
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy))),
|
||||
TLSConfig: tlsConfig,
|
||||
}
|
||||
s.Logger.Debugf("starting listening on reverse proxy server address %s", addr)
|
||||
return s.https.ListenAndServeTLS("", "")
|
||||
|
||||
httpsErr := make(chan error, 1)
|
||||
go func() {
|
||||
s.Logger.Debugf("starting reverse proxy server on %s", addr)
|
||||
httpsErr <- s.https.ListenAndServeTLS("", "")
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-httpsErr:
|
||||
s.shutdownServices()
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("https server: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
s.gracefulShutdown()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
// shutdownPreStopDelay is the time to wait after receiving a shutdown signal
|
||||
// before draining connections. This allows the load balancer to propagate
|
||||
// the endpoint removal.
|
||||
shutdownPreStopDelay = 5 * time.Second
|
||||
|
||||
// shutdownDrainTimeout is the maximum time to wait for in-flight HTTP
|
||||
// requests to complete during graceful shutdown.
|
||||
shutdownDrainTimeout = 30 * time.Second
|
||||
|
||||
// shutdownServiceTimeout is the maximum time to wait for auxiliary
|
||||
// services (health probe, debug endpoint, ACME) to shut down.
|
||||
shutdownServiceTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
|
||||
// readiness probe as failing, waits for load balancer propagation, drains
|
||||
// in-flight connections, and then stops all background services.
|
||||
func (s *Server) gracefulShutdown() {
|
||||
s.Logger.Info("shutdown signal received, starting graceful shutdown")
|
||||
|
||||
// Step 1: Fail readiness probe so load balancers stop routing new traffic.
|
||||
if s.healthChecker != nil {
|
||||
s.healthChecker.SetShuttingDown()
|
||||
}
|
||||
|
||||
// Step 2: When running behind a load balancer, wait for endpoint removal
|
||||
// to propagate before draining connections.
|
||||
if k8s.InCluster() {
|
||||
s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay)
|
||||
time.Sleep(shutdownPreStopDelay)
|
||||
}
|
||||
|
||||
// Step 3: Stop accepting new connections and drain in-flight requests.
|
||||
drainCtx, drainCancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
|
||||
defer drainCancel()
|
||||
|
||||
s.Logger.Info("draining in-flight connections")
|
||||
if err := s.https.Shutdown(drainCtx); err != nil {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Stop all remaining background services.
|
||||
s.shutdownServices()
|
||||
s.Logger.Info("graceful shutdown complete")
|
||||
}
|
||||
|
||||
// shutdownServices stops all background services concurrently and waits for
|
||||
// them to finish.
|
||||
func (s *Server) shutdownServices() {
|
||||
var wg sync.WaitGroup
|
||||
|
||||
shutdownHTTP := func(name string, shutdown func(context.Context) error) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
|
||||
defer cancel()
|
||||
if err := shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.Logger.Debugf("%s shutdown: %v", name, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.healthServer != nil {
|
||||
shutdownHTTP("health probe", s.healthServer.Shutdown)
|
||||
}
|
||||
if s.debug != nil {
|
||||
shutdownHTTP("debug endpoint", s.debug.Shutdown)
|
||||
}
|
||||
if s.http != nil {
|
||||
shutdownHTTP("acme http", s.http.Shutdown)
|
||||
}
|
||||
|
||||
if s.netbird != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
|
||||
defer cancel()
|
||||
if err := s.netbird.StopAll(ctx); err != nil {
|
||||
s.Logger.Warnf("stop netbird clients: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
|
||||
|
||||
Reference in New Issue
Block a user