diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go
index 450927fd8..edca019e8 100644
--- a/proxy/cmd/proxy/cmd/root.go
+++ b/proxy/cmd/proxy/cmd/root.go
@@ -4,14 +4,17 @@ import (
 	"context"
 	"fmt"
 	"os"
+	"os/signal"
 	"strconv"
 	"strings"
+	"syscall"
 
 	log "github.com/sirupsen/logrus"
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/acme"
 
 	"github.com/netbirdio/netbird/proxy"
+	nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
 	"github.com/netbirdio/netbird/util"
 )
@@ -47,6 +50,7 @@ var (
 	trustedProxies string
 	certFile       string
 	certKeyFile    string
+	certLockMethod string
 )
 
 var rootCmd = &cobra.Command{
@@ -78,6 +82,7 @@ func init() {
 	rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
 	rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
 	rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
+	rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
 }
 
 // Execute runs the root command.
@@ -145,9 +150,13 @@ func runServer(cmd *cobra.Command, args []string) error {
 		OIDCScopes:     strings.Split(oidcScopes, ","),
 		ForwardedProto: forwardedProto,
 		TrustedProxies: parsedTrustedProxies,
+		CertLockMethod: nbacme.CertLockMethod(certLockMethod),
 	}
 
-	if err := srv.ListenAndServe(context.TODO(), addr); err != nil {
+	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
+	defer stop()
+
+	if err := srv.ListenAndServe(ctx, addr); err != nil {
 		log.Fatal(err)
 	}
 	return nil
diff --git a/proxy/internal/acme/locker.go b/proxy/internal/acme/locker.go
index 11c3abdfe..1ab330603 100644
--- a/proxy/internal/acme/locker.go
+++ b/proxy/internal/acme/locker.go
@@ -68,6 +68,9 @@ type flockLocker struct {
 }
 
 func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
+	if logger == nil {
+		logger = log.StandardLogger()
+	}
 	return &flockLocker{certDir: certDir, logger: logger}
 }
diff --git a/proxy/internal/flock/flock_unix.go b/proxy/internal/flock/flock_unix.go
index 84a77e1dc..ceb47ff7a 100644
--- a/proxy/internal/flock/flock_unix.go
+++ b/proxy/internal/flock/flock_unix.go
@@ -32,6 +32,9 @@ func Lock(ctx context.Context, path string) (*os.File, error) {
 		return nil, fmt.Errorf("open lock file %s: %w", path, err)
 	}
 
+	timer := time.NewTimer(retryInterval)
+	defer timer.Stop()
+
 	for {
 		if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
 			return f, nil
@@ -48,7 +51,8 @@ func Lock(ctx context.Context, path string) (*os.File, error) {
 			log.Debugf("close lock file %s: %v", path, cerr)
 		}
 		return nil, ctx.Err()
-		case <-time.After(retryInterval):
+		case <-timer.C:
+			timer.Reset(retryInterval)
 		}
 	}
 }
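Note on the flock_unix.go change above: time.After allocates a fresh timer on every retry iteration (and before Go 1.23 those timers were not reclaimed until they fired), so a long-contended lock churned allocations; a single reusable timer avoids that. A minimal standalone sketch of the pattern, where tryLock is a hypothetical stand-in for the non-blocking syscall.Flock attempt:

package main

import (
	"context"
	"errors"
	"time"
)

var errWouldBlock = errors.New("lock held elsewhere")

// acquire retries tryLock until it succeeds, fails hard, or ctx is done.
func acquire(ctx context.Context, tryLock func() error, retry time.Duration) error {
	// One timer for the whole loop; Stop releases it on every exit path.
	timer := time.NewTimer(retry)
	defer timer.Stop()

	for {
		if err := tryLock(); err == nil {
			return nil
		} else if !errors.Is(err, errWouldBlock) {
			return err // real error, not contention
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
			// Safe to Reset: the timer has fired, so its channel is drained.
			timer.Reset(retry)
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	_ = acquire(ctx, func() error { return errWouldBlock }, 100*time.Millisecond)
}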
diff --git a/proxy/internal/health/health.go b/proxy/internal/health/health.go
index 36ed51674..bef968d27 100644
--- a/proxy/internal/health/health.go
+++ b/proxy/internal/health/health.go
@@ -17,8 +17,8 @@ import (
 )
 
 const (
-	maxConcurrentChecks = 3
-	maxClientCheckTimeout = 5 * time.Minute
+	maxConcurrentChecks   = 3
+	maxClientCheckTimeout = 5 * time.Minute
 )
 
 // clientProvider provides access to NetBird clients for health checks.
@@ -34,6 +34,7 @@ type Checker struct {
 	mu                  sync.RWMutex
 	managementConnected bool
 	initialSyncComplete bool
+	shuttingDown        bool
 
 	// checkSem limits concurrent client health checks.
 	checkSem chan struct{}
@@ -77,6 +78,14 @@ func (c *Checker) SetInitialSyncComplete() {
 	c.initialSyncComplete = true
 }
 
+// SetShuttingDown marks the server as shutting down.
+// This causes ReadinessProbe to return false so load balancers stop routing traffic.
+func (c *Checker) SetShuttingDown() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.shuttingDown = true
+}
+
 // CheckClientsConnected verifies all clients are connected to management/signal/relay.
 // Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout.
 // Limits concurrent checks via semaphore.
@@ -145,6 +154,9 @@ func (c *Checker) LivenessProbe() bool {
 func (c *Checker) ReadinessProbe() bool {
 	c.mu.RLock()
 	defer c.mu.RUnlock()
+	if c.shuttingDown {
+		return false
+	}
 	return c.managementConnected
 }
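For context, the new shuttingDown flag only takes effect once ReadinessProbe's result is surfaced over HTTP to the load balancer. A sketch of that wiring, assuming a /readyz route and a pared-down checker (both illustrative; the actual probe server wiring lives elsewhere in the proxy):

package main

import (
	"net/http"
	"sync"
)

// checker mirrors the ReadinessProbe semantics of health.Checker
// in miniature: not ready while shutting down or disconnected.
type checker struct {
	mu                  sync.RWMutex
	managementConnected bool
	shuttingDown        bool
}

func (c *checker) SetShuttingDown() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.shuttingDown = true
}

func (c *checker) ReadinessProbe() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return !c.shuttingDown && c.managementConnected
}

func main() {
	c := &checker{managementConnected: true}
	http.HandleFunc("/readyz", func(w http.ResponseWriter, r *http.Request) {
		if !c.ReadinessProbe() {
			// 503 tells the load balancer to stop routing traffic here.
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
	})
	_ = http.ListenAndServe("localhost:8080", nil)
}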
diff --git a/proxy/internal/k8s/lease.go b/proxy/internal/k8s/lease.go
index cef15521a..4b67db80d 100644
--- a/proxy/internal/k8s/lease.go
+++ b/proxy/internal/k8s/lease.go
@@ -10,7 +10,6 @@ import (
-	"bytes"
 	"crypto/tls"
 	"crypto/x509"
 	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
diff --git a/proxy/server.go b/proxy/server.go
index c981e6f84..fc219933f 100644
--- a/proxy/server.go
+++ b/proxy/server.go
@@ -19,6 +19,7 @@ import (
 	"net/netip"
 	"net/url"
 	"path/filepath"
+	"sync"
 	"time"
 
 	backoff "github.com/cenkalti/backoff/v4"
@@ -36,6 +37,7 @@ import (
 	"github.com/netbirdio/netbird/proxy/internal/debug"
 	proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
 	"github.com/netbirdio/netbird/proxy/internal/health"
+	"github.com/netbirdio/netbird/proxy/internal/k8s"
 	"github.com/netbirdio/netbird/proxy/internal/proxy"
 	"github.com/netbirdio/netbird/proxy/internal/roundtrip"
 	"github.com/netbirdio/netbird/proxy/internal/types"
@@ -71,10 +73,13 @@ type Server struct {
 	GenerateACMECertificates bool
 	ACMEChallengeAddress     string
 	ACMEDirectory            string
-	OIDCClientId     string
-	OIDCClientSecret string
-	OIDCEndpoint     string
-	OIDCScopes       []string
+	// CertLockMethod controls how ACME certificate locks are coordinated
+	// across replicas. Default: CertLockAuto (detect environment).
+	CertLockMethod   acme.CertLockMethod
+	OIDCClientId     string
+	OIDCClientSecret string
+	OIDCEndpoint     string
+	OIDCScopes       []string
 
 	// DebugEndpointEnabled enables the debug HTTP endpoint.
 	DebugEndpointEnabled bool
@@ -173,6 +178,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
 	if err != nil {
 		return fmt.Errorf("could not create management connection: %w", err)
 	}
+	defer func() {
+		if err := mgmtConn.Close(); err != nil {
+			s.Logger.Debugf("management connection close: %v", err)
+		}
+	}()
 
 	s.mgmtClient = proto.NewProxyServiceClient(mgmtConn)
 	go s.newManagementMappingWorker(ctx, s.mgmtClient)
@@ -184,17 +194,14 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
 	tlsConfig := &tls.Config{}
 	if s.GenerateACMECertificates {
 		s.Logger.WithField("acme_server", s.ACMEDirectory).Debug("ACME certificates enabled, configuring certificate manager")
-		s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger)
+		s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod)
 		s.http = &http.Server{
 			Addr:    s.ACMEChallengeAddress,
 			Handler: s.acme.HTTPHandler(nil),
 		}
 		go func() {
-			if err := s.http.ListenAndServe(); err != nil {
-				// Rather than retry, log the issue periodically so that hopefully someone notices and fixes the issue.
-				for range time.Tick(10 * time.Second) {
-					s.Logger.WithError(err).Error("ACME HTTP-01 challenge server error")
-				}
+			if err := s.http.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				s.Logger.WithError(err).Error("ACME HTTP-01 challenge server failed")
 			}
 		}()
 		tlsConfig = s.acme.TLSConfig()
@@ -250,14 +257,9 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
 				s.Logger.Errorf("debug endpoint error: %v", err)
 			}
 		}()
-		defer func() {
-			if err := s.debug.Close(); err != nil {
-				s.Logger.Debugf("debug endpoint close: %v", err)
-			}
-		}()
 	}
 
-	// Start health probe server on separate port for Kubernetes probes.
+	// Start health probe server.
 	healthAddr := s.HealthAddress
 	if healthAddr == "" {
 		healthAddr = "localhost:8080"
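The final server.go hunk below replaces the blocking `return s.https.ListenAndServeTLS("", "")` with a goroutine feeding an error channel, plus a select against ctx.Done(). For reference, a self-contained sketch of that run-until-signal pattern, with plain HTTP and an illustrative drain timeout standing in for the proxy's TLS setup:

package main

import (
	"context"
	"errors"
	"log"
	"net/http"
	"os/signal"
	"syscall"
	"time"
)

func main() {
	// Cancel ctx on SIGTERM/SIGINT, as the root command now does.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
	defer stop()

	srv := &http.Server{Addr: "localhost:8443"}

	// Buffered so the serving goroutine never blocks on send.
	errCh := make(chan error, 1)
	go func() { errCh <- srv.ListenAndServe() }()

	select {
	case err := <-errCh:
		// Server died on its own; ErrServerClosed is the clean case.
		if !errors.Is(err, http.ErrServerClosed) {
			log.Fatalf("server: %v", err)
		}
	case <-ctx.Done():
		// Signal received: drain in-flight requests with a bounded timeout.
		drainCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		if err := srv.Shutdown(drainCtx); err != nil {
			log.Printf("drain: %v", err)
		}
	}
}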
@@ -272,30 +274,120 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
 			s.Logger.Errorf("health probe server: %v", err)
 		}
 	}()
-	defer func() {
-		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-		defer cancel()
-		if err := s.healthServer.Shutdown(shutdownCtx); err != nil {
-			s.Logger.Debugf("health probe server shutdown: %v", err)
-		}
-	}()
-	defer func() {
-		stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
-		defer cancel()
-		if err := s.netbird.StopAll(stopCtx); err != nil {
-			s.Logger.Warnf("failed to stop all netbird clients: %v", err)
-		}
-	}()
-
-	// Finally, start the reverse proxy.
+	// Start the reverse proxy HTTPS server.
 	s.https = &http.Server{
 		Addr:      addr,
 		Handler:   accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy))),
 		TLSConfig: tlsConfig,
 	}
-	s.Logger.Debugf("starting listening on reverse proxy server address %s", addr)
-	return s.https.ListenAndServeTLS("", "")
+
+	httpsErr := make(chan error, 1)
+	go func() {
+		s.Logger.Debugf("starting reverse proxy server on %s", addr)
+		httpsErr <- s.https.ListenAndServeTLS("", "")
+	}()
+
+	select {
+	case err := <-httpsErr:
+		s.shutdownServices()
+		if !errors.Is(err, http.ErrServerClosed) {
+			return fmt.Errorf("https server: %w", err)
+		}
+		return nil
+	case <-ctx.Done():
+		s.gracefulShutdown()
+		return nil
+	}
+}
+
+const (
+	// shutdownPreStopDelay is the time to wait after receiving a shutdown
+	// signal before draining connections, giving load balancers time to
+	// observe the failing readiness probe and remove the endpoint.
+	shutdownPreStopDelay = 5 * time.Second
+
+	// shutdownDrainTimeout is the maximum time to wait for in-flight HTTP
+	// requests to complete during graceful shutdown.
+	shutdownDrainTimeout = 30 * time.Second
+
+	// shutdownServiceTimeout is the maximum time to wait for auxiliary
+	// services (health probe, debug endpoint, ACME) to shut down.
+	shutdownServiceTimeout = 5 * time.Second
+)
+
+// gracefulShutdown performs a zero-downtime shutdown sequence. It marks the
+// readiness probe as failing, waits for load balancer propagation, drains
+// in-flight connections, and then stops all background services.
+func (s *Server) gracefulShutdown() {
+	s.Logger.Info("shutdown signal received, starting graceful shutdown")
+
+	// Step 1: Fail readiness probe so load balancers stop routing new traffic.
+	if s.healthChecker != nil {
+		s.healthChecker.SetShuttingDown()
+	}
+
+	// Step 2: When running in Kubernetes, wait for the endpoint removal to
+	// propagate to load balancers before draining connections.
+	if k8s.InCluster() {
+		s.Logger.Infof("waiting %s for load balancer propagation", shutdownPreStopDelay)
+		time.Sleep(shutdownPreStopDelay)
+	}
+
+	// Step 3: Stop accepting new connections and drain in-flight requests.
+	drainCtx, drainCancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
+	defer drainCancel()
+
+	s.Logger.Info("draining in-flight connections")
+	if err := s.https.Shutdown(drainCtx); err != nil {
+		s.Logger.Warnf("https server drain: %v", err)
+	}
+
+	// Step 4: Stop all remaining background services.
+	s.shutdownServices()
+	s.Logger.Info("graceful shutdown complete")
+}
+
+// shutdownServices stops all background services concurrently and waits for
+// them to finish.
+func (s *Server) shutdownServices() {
+	var wg sync.WaitGroup
+
+	shutdownHTTP := func(name string, shutdown func(context.Context) error) {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			ctx, cancel := context.WithTimeout(context.Background(), shutdownServiceTimeout)
+			defer cancel()
+			if err := shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				s.Logger.Debugf("%s shutdown: %v", name, err)
+			}
+		}()
+	}
+
+	if s.healthServer != nil {
+		shutdownHTTP("health probe", s.healthServer.Shutdown)
+	}
+	if s.debug != nil {
+		shutdownHTTP("debug endpoint", s.debug.Shutdown)
+	}
+	if s.http != nil {
+		shutdownHTTP("acme http", s.http.Shutdown)
+	}
+
+	if s.netbird != nil {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			ctx, cancel := context.WithTimeout(context.Background(), shutdownDrainTimeout)
+			defer cancel()
+			if err := s.netbird.StopAll(ctx); err != nil {
+				s.Logger.Warnf("stop netbird clients: %v", err)
+			}
+		}()
+	}
+
+	wg.Wait()
 }
 
 func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
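Design note: shutdownServices fans each service's shutdown out to its own goroutine, so total shutdown time is bounded by the slowest service rather than the sum, and a hung auxiliary server cannot stall the netbird client teardown. A reduced standalone sketch of that fan-out, with an illustrative service set and timeout:

package main

import (
	"context"
	"errors"
	"log"
	"net/http"
	"sync"
	"time"
)

// shutdownAll stops every server concurrently, each with its own timeout.
func shutdownAll(servers map[string]*http.Server, timeout time.Duration) {
	var wg sync.WaitGroup
	for name, srv := range servers {
		wg.Add(1)
		go func(name string, srv *http.Server) {
			defer wg.Done()
			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			defer cancel()
			if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
				log.Printf("%s shutdown: %v", name, err)
			}
		}(name, srv)
	}
	// Wait returns once every service has stopped or hit its timeout.
	wg.Wait()
}

func main() {
	shutdownAll(map[string]*http.Server{
		"health": {Addr: "localhost:8080"},
		"debug":  {Addr: "localhost:6060"},
	}, 5*time.Second)
}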