From 7d844b94107f7198da8ffd10fba227980a8ed729 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 4 Feb 2026 21:23:00 +0800 Subject: [PATCH] Add health checks --- client/embed/embed.go | 14 +- proxy/Dockerfile | 29 ++- proxy/cmd/proxy/cmd/root.go | 3 + proxy/deploy/k8s/deployment.yaml | 108 +++++++++ proxy/deploy/kind-config.yaml | 11 + proxy/internal/debug/handler.go | 8 +- proxy/internal/health/health.go | 340 +++++++++++++++++++++++++++ proxy/internal/health/health_test.go | 155 ++++++++++++ proxy/internal/roundtrip/netbird.go | 60 +++-- proxy/server.go | 74 +++++- 10 files changed, 748 insertions(+), 54 deletions(-) create mode 100644 proxy/deploy/k8s/deployment.yaml create mode 100644 proxy/deploy/kind-config.yaml create mode 100644 proxy/internal/health/health.go create mode 100644 proxy/internal/health/health_test.go diff --git a/client/embed/embed.go b/client/embed/embed.go index 2ad025ff0..515d78d51 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -162,6 +162,7 @@ func New(opts Options) (*Client, error) { setupKey: opts.SetupKey, jwtToken: opts.JWTToken, config: config, + recorder: peer.NewRecorder(config.ManagementURL.String()), }, nil } @@ -183,6 +184,7 @@ func (c *Client) Start(startCtx context.Context) error { // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) + authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config) if err != nil { return fmt.Errorf("create auth client: %w", err) @@ -192,10 +194,7 @@ func (c *Client) Start(startCtx context.Context) error { if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil { return fmt.Errorf("login: %w", err) } - - recorder := peer.NewRecorder(c.config.ManagementURL.String()) - c.recorder = recorder - client := internal.NewConnectClient(ctx, c.config, recorder, false) + client := internal.NewConnectClient(ctx, c.config, c.recorder, false) client.SetSyncResponsePersistence(true) // either startup error (permanent backoff err) or nil err (successful engine up) @@ -348,14 +347,9 @@ func (c *Client) NewHTTPClient() *http.Client { // Status returns the current status of the client. func (c *Client) Status() (peer.FullStatus, error) { c.mu.Lock() - recorder := c.recorder connect := c.connect c.mu.Unlock() - if recorder == nil { - return peer.FullStatus{}, errors.New("client not started") - } - if connect != nil { engine := connect.Engine() if engine != nil { @@ -363,7 +357,7 @@ func (c *Client) Status() (peer.FullStatus, error) { } } - return recorder.GetFullStatus(), nil + return c.recorder.GetFullStatus(), nil } // GetLatestSyncResponse returns the latest sync response from the management server. diff --git a/proxy/Dockerfile b/proxy/Dockerfile index da5182ad1..89c9821c7 100644 --- a/proxy/Dockerfile +++ b/proxy/Dockerfile @@ -1,5 +1,24 @@ -FROM ubuntu:24.04 -RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt -ENTRYPOINT [ "/go/bin/netbird-proxy"] -CMD [] -COPY netbird-proxy /go/bin/netbird-proxy +FROM golang:1.25-alpine AS builder +WORKDIR /app + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . 
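+# Build a static, stripped proxy binary for the distroless runtime stage below.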
+RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/cmd/proxy + +RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \ + echo "netbird:x:1000:netbird" > /tmp/group && \ + mkdir -p /tmp/var/lib/netbird + +FROM gcr.io/distroless/base:debug +COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy +COPY --from=builder /tmp/passwd /etc/passwd +COPY --from=builder /tmp/group /etc/group +COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird +USER netbird:netbird +ENV HOME=/var/lib/netbird +ENV NB_PROXY_ADDRESS=":8443" +ENV NB_PROXY_HEALTH_ADDRESS=":8080" +EXPOSE 8443 +ENTRYPOINT ["/usr/bin/netbird-proxy"] diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 3bff9a834..0a8cd6de5 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -34,6 +34,7 @@ var ( acmeDir string debugEndpoint bool debugEndpointAddr string + healthAddr string oidcClientID string oidcClientSecret string oidcEndpoint string @@ -59,6 +60,7 @@ func init() { rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory") rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint") rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint") + rootCmd.Flags().StringVar(&healthAddr, "health-addr", envStringOrDefault("NB_PROXY_HEALTH_ADDRESS", "localhost:8080"), "Address for the health probe endpoint (liveness/readiness/startup)") rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication") rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication") rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication") @@ -104,6 +106,7 @@ func runServer(cmd *cobra.Command, args []string) error { ACMEDirectory: acmeDir, DebugEndpointEnabled: debugEndpoint, DebugEndpointAddress: debugEndpointAddr, + HealthAddress: healthAddr, OIDCClientId: oidcClientID, OIDCClientSecret: oidcClientSecret, OIDCEndpoint: oidcEndpoint, diff --git a/proxy/deploy/k8s/deployment.yaml b/proxy/deploy/k8s/deployment.yaml new file mode 100644 index 000000000..94b1e4e9e --- /dev/null +++ b/proxy/deploy/k8s/deployment.yaml @@ -0,0 +1,108 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: netbird-proxy + labels: + app: netbird-proxy +spec: + replicas: 1 + selector: + matchLabels: + app: netbird-proxy + template: + metadata: + labels: + app: netbird-proxy + spec: + hostAliases: + - ip: "192.168.100.1" + hostnames: + - "host.docker.internal" + containers: + - name: proxy + image: netbird-proxy + ports: + - containerPort: 8443 + name: https + - containerPort: 8080 + name: health + - containerPort: 8444 + name: debug + env: + - name: USER + value: "netbird" + - name: HOME + value: "/tmp" + - name: NB_PROXY_DEBUG_LOGS + value: "true" + - name: NB_PROXY_MANAGEMENT_ADDRESS + value: "http://host.docker.internal:8080" + - name: NB_PROXY_ADDRESS + value: ":8443" + - name: NB_PROXY_HEALTH_ADDRESS + value: ":8080" + - name: NB_PROXY_DEBUG_ENDPOINT + 
value: "true" + - name: NB_PROXY_DEBUG_ENDPOINT_ADDRESS + value: ":8444" + - name: NB_PROXY_URL + value: "https://proxy.local" + - name: NB_PROXY_CERTIFICATE_DIRECTORY + value: "/certs" + volumeMounts: + - name: tls-certs + mountPath: /certs + readOnly: true + livenessProbe: + httpGet: + path: /healthz/live + port: health + initialDelaySeconds: 5 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 3 + readinessProbe: + httpGet: + path: /healthz/ready + port: health + initialDelaySeconds: 5 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 3 + startupProbe: + httpGet: + path: /healthz/startup + port: health + periodSeconds: 2 + timeoutSeconds: 10 + failureThreshold: 60 + resources: + requests: + memory: "64Mi" + cpu: "100m" + limits: + memory: "256Mi" + cpu: "500m" + volumes: + - name: tls-certs + secret: + secretName: netbird-proxy-tls +--- +apiVersion: v1 +kind: Service +metadata: + name: netbird-proxy +spec: + selector: + app: netbird-proxy + ports: + - name: https + port: 8443 + targetPort: 8443 + - name: health + port: 8080 + targetPort: 8080 + - name: debug + port: 8444 + targetPort: 8444 + type: ClusterIP diff --git a/proxy/deploy/kind-config.yaml b/proxy/deploy/kind-config.yaml new file mode 100644 index 000000000..d40f1eb36 --- /dev/null +++ b/proxy/deploy/kind-config.yaml @@ -0,0 +1,11 @@ +kind: Cluster +apiVersion: kind.x-k8s.io/v1alpha4 +nodes: +- role: control-plane + extraPortMappings: + - containerPort: 30080 + hostPort: 30080 + protocol: TCP + - containerPort: 30443 + hostPort: 30443 + protocol: TCP diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 43cfb9533..f7b1fa87c 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -46,15 +46,15 @@ func formatDuration(d time.Duration) string { } } -// ClientProvider provides access to NetBird clients. -type ClientProvider interface { +// clientProvider provides access to NetBird clients. +type clientProvider interface { GetClient(accountID types.AccountID) (*nbembed.Client, bool) ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo } // Handler provides HTTP debug endpoints. type Handler struct { - provider ClientProvider + provider clientProvider logger *log.Logger startTime time.Time templates *template.Template @@ -62,7 +62,7 @@ type Handler struct { } // NewHandler creates a new debug handler. -func NewHandler(provider ClientProvider, logger *log.Logger) *Handler { +func NewHandler(provider clientProvider, logger *log.Logger) *Handler { if logger == nil { logger = log.StandardLogger() } diff --git a/proxy/internal/health/health.go b/proxy/internal/health/health.go new file mode 100644 index 000000000..36ed51674 --- /dev/null +++ b/proxy/internal/health/health.go @@ -0,0 +1,340 @@ +// Package health provides health probes for the proxy server. +package health + +import ( + "context" + "encoding/json" + "net" + "net/http" + "strings" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +const ( + maxConcurrentChecks = 3 + maxClientCheckTimeout = 5 * time.Minute +) + +// clientProvider provides access to NetBird clients for health checks. +type clientProvider interface { + ListClientsForStartup() map[types.AccountID]*embed.Client +} + +// Checker tracks health state and provides probe endpoints. 
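+// All mutable state is guarded by mu, so a single Checker can be shared safely
+// between the probe HTTP handlers and the mapping update worker.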
+type Checker struct { + logger *log.Logger + provider clientProvider + + mu sync.RWMutex + managementConnected bool + initialSyncComplete bool + + // checkSem limits concurrent client health checks. + checkSem chan struct{} +} + +// ClientHealth represents the health status of a single NetBird client. +type ClientHealth struct { + Healthy bool `json:"healthy"` + ManagementConnected bool `json:"management_connected"` + SignalConnected bool `json:"signal_connected"` + RelaysConnected int `json:"relays_connected"` + RelaysTotal int `json:"relays_total"` + Error string `json:"error,omitempty"` +} + +// ProbeResponse represents the JSON response for health probes. +type ProbeResponse struct { + Status string `json:"status"` + Checks map[string]bool `json:"checks,omitempty"` + Clients map[types.AccountID]ClientHealth `json:"clients,omitempty"` +} + +// Server runs the health probe HTTP server on a dedicated port. +type Server struct { + server *http.Server + logger *log.Logger + checker *Checker +} + +// SetManagementConnected updates the management connection state. +func (c *Checker) SetManagementConnected(connected bool) { + c.mu.Lock() + defer c.mu.Unlock() + c.managementConnected = connected +} + +// SetInitialSyncComplete marks that the initial mapping sync has completed. +func (c *Checker) SetInitialSyncComplete() { + c.mu.Lock() + defer c.mu.Unlock() + c.initialSyncComplete = true +} + +// CheckClientsConnected verifies all clients are connected to management/signal/relay. +// Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout. +// Limits concurrent checks via semaphore. +func (c *Checker) CheckClientsConnected(ctx context.Context) (bool, map[types.AccountID]ClientHealth) { + // Apply upper bound timeout in case parent context has no deadline + ctx, cancel := context.WithTimeout(ctx, maxClientCheckTimeout) + defer cancel() + + clients := c.provider.ListClientsForStartup() + + // No clients yet means not ready + if len(clients) == 0 { + return false, make(map[types.AccountID]ClientHealth) + } + + type result struct { + accountID types.AccountID + health ClientHealth + } + + resultsCh := make(chan result, len(clients)) + var wg sync.WaitGroup + + for accountID, client := range clients { + wg.Add(1) + go func(id types.AccountID, cl *embed.Client) { + defer wg.Done() + + // Acquire semaphore + select { + case c.checkSem <- struct{}{}: + defer func() { <-c.checkSem }() + case <-ctx.Done(): + resultsCh <- result{id, ClientHealth{Healthy: false, Error: ctx.Err().Error()}} + return + } + + resultsCh <- result{id, checkClientHealth(cl)} + }(accountID, client) + } + + go func() { + wg.Wait() + close(resultsCh) + }() + + results := make(map[types.AccountID]ClientHealth) + allHealthy := true + for r := range resultsCh { + results[r.accountID] = r.health + if !r.health.Healthy { + allHealthy = false + } + } + + return allHealthy, results +} + +// LivenessProbe returns true if the process is alive. +// This should always return true if we can respond. +func (c *Checker) LivenessProbe() bool { + return true +} + +// ReadinessProbe returns true if the server can accept traffic. +func (c *Checker) ReadinessProbe() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.managementConnected +} + +// StartupProbe checks if initial startup is complete. +// Checks management connection, initial sync, and all client health directly. +// Uses the provided context for timeout/cancellation. 
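+// Kubernetes only calls the startup probe until it first succeeds, so the
+// relatively expensive per-client check is not part of steady-state probing.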
+func (c *Checker) StartupProbe(ctx context.Context) bool { + c.mu.RLock() + mgmt := c.managementConnected + sync := c.initialSyncComplete + c.mu.RUnlock() + + if !mgmt || !sync { + return false + } + + // Check all clients are connected to management/signal/relay + allHealthy, _ := c.CheckClientsConnected(ctx) + return allHealthy +} + +// Handler returns an http.Handler for health probe endpoints. +func (c *Checker) Handler() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/healthz/live", c.handleLiveness) + mux.HandleFunc("/healthz/ready", c.handleReadiness) + mux.HandleFunc("/healthz/startup", c.handleStartup) + mux.HandleFunc("/healthz", c.handleFull) + return mux +} + +func (c *Checker) handleLiveness(w http.ResponseWriter, r *http.Request) { + if c.LivenessProbe() { + c.writeProbeResponse(w, http.StatusOK, "ok", nil, nil) + return + } + c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", nil, nil) +} + +func (c *Checker) handleReadiness(w http.ResponseWriter, r *http.Request) { + c.mu.RLock() + checks := map[string]bool{ + "management_connected": c.managementConnected, + } + c.mu.RUnlock() + + if c.ReadinessProbe() { + c.writeProbeResponse(w, http.StatusOK, "ok", checks, nil) + return + } + c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, nil) +} + +func (c *Checker) handleStartup(w http.ResponseWriter, r *http.Request) { + c.mu.RLock() + mgmt := c.managementConnected + sync := c.initialSyncComplete + c.mu.RUnlock() + + // Check clients directly using request context + allClientsHealthy, clientHealth := c.CheckClientsConnected(r.Context()) + + checks := map[string]bool{ + "management_connected": mgmt, + "initial_sync_complete": sync, + "all_clients_healthy": allClientsHealthy, + } + + if c.StartupProbe(r.Context()) { + c.writeProbeResponse(w, http.StatusOK, "ok", checks, clientHealth) + return + } + c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, clientHealth) +} + +func (c *Checker) handleFull(w http.ResponseWriter, r *http.Request) { + c.mu.RLock() + mgmt := c.managementConnected + sync := c.initialSyncComplete + c.mu.RUnlock() + + allClientsHealthy, clientHealth := c.CheckClientsConnected(r.Context()) + + checks := map[string]bool{ + "management_connected": mgmt, + "initial_sync_complete": sync, + "all_clients_healthy": allClientsHealthy, + } + + status := "ok" + statusCode := http.StatusOK + if !c.ReadinessProbe() { + status = "fail" + statusCode = http.StatusServiceUnavailable + } + + c.writeProbeResponse(w, statusCode, status, checks, clientHealth) +} + +func (c *Checker) writeProbeResponse(w http.ResponseWriter, statusCode int, status string, checks map[string]bool, clients map[types.AccountID]ClientHealth) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + + resp := ProbeResponse{ + Status: status, + Checks: checks, + Clients: clients, + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + c.logger.Debugf("write health response: %v", err) + } +} + +// ListenAndServe starts the health probe server. +func (s *Server) ListenAndServe() error { + s.logger.Infof("starting health probe server on %s", s.server.Addr) + return s.server.ListenAndServe() +} + +// Serve starts the health probe server on the given listener. +func (s *Server) Serve(l net.Listener) error { + s.logger.Infof("starting health probe server on %s", l.Addr()) + return s.server.Serve(l) +} + +// Shutdown gracefully shuts down the health probe server. 
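+// It stops accepting new probe requests and waits for in-flight requests to
+// finish, or until the provided context is cancelled.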
+func (s *Server) Shutdown(ctx context.Context) error { + return s.server.Shutdown(ctx) +} + +// NewChecker creates a new health checker. +func NewChecker(logger *log.Logger, provider clientProvider) *Checker { + if logger == nil { + logger = log.StandardLogger() + } + return &Checker{ + logger: logger, + provider: provider, + checkSem: make(chan struct{}, maxConcurrentChecks), + } +} + +// NewServer creates a new health probe server. +func NewServer(addr string, checker *Checker, logger *log.Logger) *Server { + if logger == nil { + logger = log.StandardLogger() + } + return &Server{ + server: &http.Server{ + Addr: addr, + Handler: checker.Handler(), + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + }, + logger: logger, + checker: checker, + } +} + +func checkClientHealth(client *embed.Client) ClientHealth { + status, err := client.Status() + if err != nil { + return ClientHealth{ + Healthy: false, + Error: err.Error(), + } + } + + // Count only rel:// and rels:// relays (not stun/turn) + var relayCount, relaysConnected int + for _, relay := range status.Relays { + if !strings.HasPrefix(relay.URI, "rel://") && !strings.HasPrefix(relay.URI, "rels://") { + continue + } + relayCount++ + if relay.Err == nil { + relaysConnected++ + } + } + + // Client is healthy if connected to management, signal, and at least one relay (if any are defined) + healthy := status.ManagementState.Connected && + status.SignalState.Connected && + (relayCount == 0 || relaysConnected > 0) + + return ClientHealth{ + Healthy: healthy, + ManagementConnected: status.ManagementState.Connected, + SignalConnected: status.SignalState.Connected, + RelaysConnected: relaysConnected, + RelaysTotal: relayCount, + } +} diff --git a/proxy/internal/health/health_test.go b/proxy/internal/health/health_test.go new file mode 100644 index 000000000..d824bcf90 --- /dev/null +++ b/proxy/internal/health/health_test.go @@ -0,0 +1,155 @@ +package health + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/proxy/internal/types" +) + +type mockClientProvider struct { + clients map[types.AccountID]*embed.Client +} + +func (m *mockClientProvider) ListClientsForStartup() map[types.AccountID]*embed.Client { + return m.clients +} + +func TestChecker_LivenessProbe(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + + // Liveness should always return true if we can respond. + assert.True(t, checker.LivenessProbe()) +} + +func TestChecker_ReadinessProbe(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + + // Initially not ready (management not connected). + assert.False(t, checker.ReadinessProbe()) + + // After management connects, should be ready. + checker.SetManagementConnected(true) + assert.True(t, checker.ReadinessProbe()) + + // If management disconnects, should not be ready. + checker.SetManagementConnected(false) + assert.False(t, checker.ReadinessProbe()) +} + +func TestChecker_StartupProbe_NoClients(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + + // Initially startup not complete. + assert.False(t, checker.StartupProbe(context.Background())) + + // Just management connected is not enough. 
+ checker.SetManagementConnected(true) + assert.False(t, checker.StartupProbe(context.Background())) + + // Management + initial sync but no clients = not ready + checker.SetInitialSyncComplete() + assert.False(t, checker.StartupProbe(context.Background())) +} + +func TestChecker_Handler_Liveness(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + handler := checker.Handler() + + req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp ProbeResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + assert.Equal(t, "ok", resp.Status) +} + +func TestChecker_Handler_Readiness_NotReady(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + handler := checker.Handler() + + req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var resp ProbeResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + assert.Equal(t, "fail", resp.Status) + assert.False(t, resp.Checks["management_connected"]) +} + +func TestChecker_Handler_Readiness_Ready(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + checker.SetManagementConnected(true) + handler := checker.Handler() + + req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp ProbeResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + assert.Equal(t, "ok", resp.Status) + assert.True(t, resp.Checks["management_connected"]) +} + +func TestChecker_Handler_Startup_NotComplete(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + handler := checker.Handler() + + req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var resp ProbeResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + assert.Equal(t, "fail", resp.Status) +} + +func TestChecker_Handler_Full(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + checker.SetManagementConnected(true) + handler := checker.Handler() + + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + + var resp ProbeResponse + require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp)) + assert.Equal(t, "ok", resp.Status) + assert.NotNil(t, resp.Checks) + // Clients may be empty map when no clients exist. + assert.Empty(t, resp.Clients) +} + +func TestChecker_StartupProbe_RespectsContext(t *testing.T) { + checker := NewChecker(nil, &mockClientProvider{}) + checker.SetManagementConnected(true) + checker.SetInitialSyncComplete() + + // Cancelled context should return false quickly + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + result := checker.StartupProbe(ctx) + assert.False(t, result) +} diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index b869dba88..fa5bb1a63 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -56,20 +56,18 @@ type NetBird struct { statusNotifier statusNotifier } -// NewNetBird creates a new NetBird transport. 
-func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird { - if logger == nil { - logger = log.StandardLogger() - } - return &NetBird{ - mgmtAddr: mgmtAddr, - proxyID: proxyID, - logger: logger, - clients: make(map[types.AccountID]*clientEntry), - statusNotifier: notifier, - } +// ClientDebugInfo contains debug information about a client. +type ClientDebugInfo struct { + AccountID types.AccountID + DomainCount int + Domains domain.List + HasClient bool + CreatedAt time.Time } +// accountIDContextKey is the context key for storing the account ID. +type accountIDContextKey struct{} + // AddPeer registers a domain for an account. If the account doesn't have a client yet, // one is created using the provided setup key. Multiple domains can share the same client. func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error { @@ -379,15 +377,6 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) { return entry.client, true } -// ClientDebugInfo contains debug information about a client. -type ClientDebugInfo struct { - AccountID types.AccountID - DomainCount int - Domains domain.List - HasClient bool - CreatedAt time.Time -} - // ListClientsForDebug returns information about all clients for debug purposes. func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { n.clientsMux.RLock() @@ -410,8 +399,33 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo { return result } -// accountIDContextKey is the context key for storing the account ID. -type accountIDContextKey struct{} +// ListClientsForStartup returns all embed.Client instances for health checks. +func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client { + n.clientsMux.RLock() + defer n.clientsMux.RUnlock() + + result := make(map[types.AccountID]*embed.Client) + for accountID, entry := range n.clients { + if entry.client != nil { + result[accountID] = entry.client + } + } + return result +} + +// NewNetBird creates a new NetBird transport. +func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird { + if logger == nil { + logger = log.StandardLogger() + } + return &NetBird{ + mgmtAddr: mgmtAddr, + proxyID: proxyID, + logger: logger, + clients: make(map[types.AccountID]*clientEntry), + statusNotifier: notifier, + } +} // WithAccountID adds the account ID to the context. 
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context { diff --git a/proxy/server.go b/proxy/server.go index ad1f4654c..a2123d4e8 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -32,6 +32,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/debug" + "github.com/netbirdio/netbird/proxy/internal/health" "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/netbirdio/netbird/proxy/internal/roundtrip" "github.com/netbirdio/netbird/proxy/internal/types" @@ -41,14 +42,16 @@ import ( ) type Server struct { - mgmtClient proto.ProxyServiceClient - proxy *proxy.ReverseProxy - netbird *roundtrip.NetBird - acme *acme.Manager - auth *auth.Middleware - http *http.Server - https *http.Server - debug *http.Server + mgmtClient proto.ProxyServiceClient + proxy *proxy.ReverseProxy + netbird *roundtrip.NetBird + acme *acme.Manager + auth *auth.Middleware + http *http.Server + https *http.Server + debug *http.Server + healthServer *health.Server + healthChecker *health.Checker // Mostly used for debugging on management. startTime time.Time @@ -71,6 +74,8 @@ type Server struct { DebugEndpointEnabled bool // DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444"). DebugEndpointAddress string + // HealthAddress is the address for the health probe endpoint (default: "localhost:8080"). + HealthAddress string } // NotifyStatus sends a status update to management about tunnel connectivity @@ -221,7 +226,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { Handler: debugHandler, } go func() { - s.Logger.WithField("address", debugAddr).Info("starting debug endpoint") + s.Logger.Infof("starting debug endpoint on %s", debugAddr) if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { s.Logger.Errorf("debug endpoint error: %v", err) } @@ -233,6 +238,30 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { }() } + // Start health probe server on separate port for Kubernetes probes. + healthAddr := s.HealthAddress + if healthAddr == "" { + healthAddr = "localhost:8080" + } + s.healthChecker = health.NewChecker(s.Logger, s.netbird) + s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger) + healthListener, err := net.Listen("tcp", healthAddr) + if err != nil { + return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err) + } + go func() { + if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.Logger.Errorf("health probe server: %v", err) + } + }() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.healthServer.Shutdown(shutdownCtx); err != nil { + s.Logger.Debugf("health probe server shutdown: %v", err) + } + }() + defer func() { stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -252,8 +281,15 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) { b := backoff.New(0, 0) + initialSyncDone := false for { s.Logger.Debug("Getting mapping updates from management server") + + // Mark management as disconnected while we're attempting to reconnect. 
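+		// This flips the readiness probe to failing until the mapping stream is re-established.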
+ if s.healthChecker != nil { + s.healthChecker.SetManagementConnected(false) + } + mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{ ProxyId: s.ID, Version: s.Version, @@ -270,8 +306,14 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr time.Sleep(backoffDuration) continue } + + // Mark management as connected once stream is established. + if s.healthChecker != nil { + s.healthChecker.SetManagementConnected(true) + } s.Logger.Debug("Got mapping updates client from management server") - err = s.handleMappingStream(ctx, mappingClient) + + err = s.handleMappingStream(ctx, mappingClient, &initialSyncDone) backoffDuration := b.Duration() switch { case errors.Is(err, context.Canceled), @@ -294,7 +336,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr } } -func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient) error { +func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error { for { // Check for context completion to gracefully shutdown. select { @@ -338,6 +380,14 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr } } s.Logger.Debug("Processing mapping update completed") + + // After the first mapping sync, mark initial sync complete. + // Client health is checked directly in the startup probe. + if !*initialSyncDone && s.healthChecker != nil { + s.healthChecker.SetInitialSyncComplete() + *initialSyncDone = true + s.Logger.Info("Initial mapping sync complete") + } } } } @@ -376,7 +426,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping) } if mapping.GetAuth().GetOidc() != nil { oidc := mapping.GetAuth().GetOidc() - schemes = append(schemes, auth.NewOIDC(mgmtClient, mapping.GetId(), mapping.GetAccountId(), auth.OIDCConfig{ + schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), auth.OIDCConfig{ Issuer: oidc.GetIssuer(), Audiences: oidc.GetAudiences(), KeysLocation: oidc.GetKeysLocation(),