Add health checks

This commit is contained in:
Viktor Liu
2026-02-04 21:23:00 +08:00
parent eeabc64a73
commit 7d844b9410
10 changed files with 748 additions and 54 deletions

View File

@@ -162,6 +162,7 @@ func New(opts Options) (*Client, error) {
setupKey: opts.SetupKey,
jwtToken: opts.JWTToken,
config: config,
recorder: peer.NewRecorder(config.ManagementURL.String()),
}, nil
}
@@ -183,6 +184,7 @@ func (c *Client) Start(startCtx context.Context) error {
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
@@ -192,10 +194,7 @@ func (c *Client) Start(startCtx context.Context) error {
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
@@ -348,14 +347,9 @@ func (c *Client) NewHTTPClient() *http.Client {
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
@@ -363,7 +357,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
}
}
return recorder.GetFullStatus(), nil
return c.recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.

View File

@@ -1,5 +1,24 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-proxy"]
CMD []
COPY netbird-proxy /go/bin/netbird-proxy
FROM golang:1.25-alpine AS builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o netbird-proxy ./proxy/cmd/proxy
RUN echo "netbird:x:1000:1000:netbird:/var/lib/netbird:/sbin/nologin" > /tmp/passwd && \
echo "netbird:x:1000:netbird" > /tmp/group && \
mkdir -p /tmp/var/lib/netbird
FROM gcr.io/distroless/base:debug
COPY --from=builder /app/netbird-proxy /usr/bin/netbird-proxy
COPY --from=builder /tmp/passwd /etc/passwd
COPY --from=builder /tmp/group /etc/group
COPY --from=builder --chown=1000:1000 /tmp/var/lib/netbird /var/lib/netbird
USER netbird:netbird
ENV HOME=/var/lib/netbird
ENV NB_PROXY_ADDRESS=":8443"
ENV NB_PROXY_HEALTH_ADDRESS=":8080"
EXPOSE 8443
ENTRYPOINT ["/usr/bin/netbird-proxy"]

View File

@@ -34,6 +34,7 @@ var (
acmeDir string
debugEndpoint bool
debugEndpointAddr string
healthAddr string
oidcClientID string
oidcClientSecret string
oidcEndpoint string
@@ -59,6 +60,7 @@ func init() {
rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory")
rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint")
rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint")
rootCmd.Flags().StringVar(&healthAddr, "health-addr", envStringOrDefault("NB_PROXY_HEALTH_ADDRESS", "localhost:8080"), "Address for the health probe endpoint (liveness/readiness/startup)")
rootCmd.Flags().StringVar(&oidcClientID, "oidc-id", envStringOrDefault("NB_PROXY_OIDC_CLIENT_ID", "netbird-proxy"), "The OAuth2 Client ID for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcClientSecret, "oidc-secret", envStringOrDefault("NB_PROXY_OIDC_CLIENT_SECRET", ""), "The OAuth2 Client Secret for OIDC User Authentication")
rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication")
@@ -104,6 +106,7 @@ func runServer(cmd *cobra.Command, args []string) error {
ACMEDirectory: acmeDir,
DebugEndpointEnabled: debugEndpoint,
DebugEndpointAddress: debugEndpointAddr,
HealthAddress: healthAddr,
OIDCClientId: oidcClientID,
OIDCClientSecret: oidcClientSecret,
OIDCEndpoint: oidcEndpoint,

View File

@@ -0,0 +1,108 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: netbird-proxy
labels:
app: netbird-proxy
spec:
replicas: 1
selector:
matchLabels:
app: netbird-proxy
template:
metadata:
labels:
app: netbird-proxy
spec:
hostAliases:
- ip: "192.168.100.1"
hostnames:
- "host.docker.internal"
containers:
- name: proxy
image: netbird-proxy
ports:
- containerPort: 8443
name: https
- containerPort: 8080
name: health
- containerPort: 8444
name: debug
env:
- name: USER
value: "netbird"
- name: HOME
value: "/tmp"
- name: NB_PROXY_DEBUG_LOGS
value: "true"
- name: NB_PROXY_MANAGEMENT_ADDRESS
value: "http://host.docker.internal:8080"
- name: NB_PROXY_ADDRESS
value: ":8443"
- name: NB_PROXY_HEALTH_ADDRESS
value: ":8080"
- name: NB_PROXY_DEBUG_ENDPOINT
value: "true"
- name: NB_PROXY_DEBUG_ENDPOINT_ADDRESS
value: ":8444"
- name: NB_PROXY_URL
value: "https://proxy.local"
- name: NB_PROXY_CERTIFICATE_DIRECTORY
value: "/certs"
volumeMounts:
- name: tls-certs
mountPath: /certs
readOnly: true
livenessProbe:
httpGet:
path: /healthz/live
port: health
initialDelaySeconds: 5
periodSeconds: 10
timeoutSeconds: 5
failureThreshold: 3
readinessProbe:
httpGet:
path: /healthz/ready
port: health
initialDelaySeconds: 5
periodSeconds: 10
timeoutSeconds: 5
failureThreshold: 3
startupProbe:
httpGet:
path: /healthz/startup
port: health
periodSeconds: 2
timeoutSeconds: 10
failureThreshold: 60
resources:
requests:
memory: "64Mi"
cpu: "100m"
limits:
memory: "256Mi"
cpu: "500m"
volumes:
- name: tls-certs
secret:
secretName: netbird-proxy-tls
---
apiVersion: v1
kind: Service
metadata:
name: netbird-proxy
spec:
selector:
app: netbird-proxy
ports:
- name: https
port: 8443
targetPort: 8443
- name: health
port: 8080
targetPort: 8080
- name: debug
port: 8444
targetPort: 8444
type: ClusterIP

View File

@@ -0,0 +1,11 @@
kind: Cluster
apiVersion: kind.x-k8s.io/v1alpha4
nodes:
- role: control-plane
extraPortMappings:
- containerPort: 30080
hostPort: 30080
protocol: TCP
- containerPort: 30443
hostPort: 30443
protocol: TCP

View File

@@ -46,15 +46,15 @@ func formatDuration(d time.Duration) string {
}
}
// ClientProvider provides access to NetBird clients.
type ClientProvider interface {
// clientProvider provides access to NetBird clients.
type clientProvider interface {
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
// Handler provides HTTP debug endpoints.
type Handler struct {
provider ClientProvider
provider clientProvider
logger *log.Logger
startTime time.Time
templates *template.Template
@@ -62,7 +62,7 @@ type Handler struct {
}
// NewHandler creates a new debug handler.
func NewHandler(provider ClientProvider, logger *log.Logger) *Handler {
func NewHandler(provider clientProvider, logger *log.Logger) *Handler {
if logger == nil {
logger = log.StandardLogger()
}

View File

@@ -0,0 +1,340 @@
// Package health provides health probes for the proxy server.
package health
import (
"context"
"encoding/json"
"net"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
)
const (
maxConcurrentChecks = 3
maxClientCheckTimeout = 5 * time.Minute
)
// clientProvider provides access to NetBird clients for health checks.
type clientProvider interface {
ListClientsForStartup() map[types.AccountID]*embed.Client
}
// Checker tracks health state and provides probe endpoints.
type Checker struct {
logger *log.Logger
provider clientProvider
mu sync.RWMutex
managementConnected bool
initialSyncComplete bool
// checkSem limits concurrent client health checks.
checkSem chan struct{}
}
// ClientHealth represents the health status of a single NetBird client.
type ClientHealth struct {
Healthy bool `json:"healthy"`
ManagementConnected bool `json:"management_connected"`
SignalConnected bool `json:"signal_connected"`
RelaysConnected int `json:"relays_connected"`
RelaysTotal int `json:"relays_total"`
Error string `json:"error,omitempty"`
}
// ProbeResponse represents the JSON response for health probes.
type ProbeResponse struct {
Status string `json:"status"`
Checks map[string]bool `json:"checks,omitempty"`
Clients map[types.AccountID]ClientHealth `json:"clients,omitempty"`
}
// Server runs the health probe HTTP server on a dedicated port.
type Server struct {
server *http.Server
logger *log.Logger
checker *Checker
}
// SetManagementConnected updates the management connection state.
func (c *Checker) SetManagementConnected(connected bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.managementConnected = connected
}
// SetInitialSyncComplete marks that the initial mapping sync has completed.
func (c *Checker) SetInitialSyncComplete() {
c.mu.Lock()
defer c.mu.Unlock()
c.initialSyncComplete = true
}
// CheckClientsConnected verifies all clients are connected to management/signal/relay.
// Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout.
// Limits concurrent checks via semaphore.
func (c *Checker) CheckClientsConnected(ctx context.Context) (bool, map[types.AccountID]ClientHealth) {
// Apply upper bound timeout in case parent context has no deadline
ctx, cancel := context.WithTimeout(ctx, maxClientCheckTimeout)
defer cancel()
clients := c.provider.ListClientsForStartup()
// No clients yet means not ready
if len(clients) == 0 {
return false, make(map[types.AccountID]ClientHealth)
}
type result struct {
accountID types.AccountID
health ClientHealth
}
resultsCh := make(chan result, len(clients))
var wg sync.WaitGroup
for accountID, client := range clients {
wg.Add(1)
go func(id types.AccountID, cl *embed.Client) {
defer wg.Done()
// Acquire semaphore
select {
case c.checkSem <- struct{}{}:
defer func() { <-c.checkSem }()
case <-ctx.Done():
resultsCh <- result{id, ClientHealth{Healthy: false, Error: ctx.Err().Error()}}
return
}
resultsCh <- result{id, checkClientHealth(cl)}
}(accountID, client)
}
go func() {
wg.Wait()
close(resultsCh)
}()
results := make(map[types.AccountID]ClientHealth)
allHealthy := true
for r := range resultsCh {
results[r.accountID] = r.health
if !r.health.Healthy {
allHealthy = false
}
}
return allHealthy, results
}
// LivenessProbe returns true if the process is alive.
// This should always return true if we can respond.
func (c *Checker) LivenessProbe() bool {
return true
}
// ReadinessProbe returns true if the server can accept traffic.
func (c *Checker) ReadinessProbe() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.managementConnected
}
// StartupProbe checks if initial startup is complete.
// Checks management connection, initial sync, and all client health directly.
// Uses the provided context for timeout/cancellation.
func (c *Checker) StartupProbe(ctx context.Context) bool {
c.mu.RLock()
mgmt := c.managementConnected
sync := c.initialSyncComplete
c.mu.RUnlock()
if !mgmt || !sync {
return false
}
// Check all clients are connected to management/signal/relay
allHealthy, _ := c.CheckClientsConnected(ctx)
return allHealthy
}
// Handler returns an http.Handler for health probe endpoints.
func (c *Checker) Handler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/healthz/live", c.handleLiveness)
mux.HandleFunc("/healthz/ready", c.handleReadiness)
mux.HandleFunc("/healthz/startup", c.handleStartup)
mux.HandleFunc("/healthz", c.handleFull)
return mux
}
func (c *Checker) handleLiveness(w http.ResponseWriter, r *http.Request) {
if c.LivenessProbe() {
c.writeProbeResponse(w, http.StatusOK, "ok", nil, nil)
return
}
c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", nil, nil)
}
func (c *Checker) handleReadiness(w http.ResponseWriter, r *http.Request) {
c.mu.RLock()
checks := map[string]bool{
"management_connected": c.managementConnected,
}
c.mu.RUnlock()
if c.ReadinessProbe() {
c.writeProbeResponse(w, http.StatusOK, "ok", checks, nil)
return
}
c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, nil)
}
func (c *Checker) handleStartup(w http.ResponseWriter, r *http.Request) {
c.mu.RLock()
mgmt := c.managementConnected
sync := c.initialSyncComplete
c.mu.RUnlock()
// Check clients directly using request context
allClientsHealthy, clientHealth := c.CheckClientsConnected(r.Context())
checks := map[string]bool{
"management_connected": mgmt,
"initial_sync_complete": sync,
"all_clients_healthy": allClientsHealthy,
}
if c.StartupProbe(r.Context()) {
c.writeProbeResponse(w, http.StatusOK, "ok", checks, clientHealth)
return
}
c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, clientHealth)
}
func (c *Checker) handleFull(w http.ResponseWriter, r *http.Request) {
c.mu.RLock()
mgmt := c.managementConnected
sync := c.initialSyncComplete
c.mu.RUnlock()
allClientsHealthy, clientHealth := c.CheckClientsConnected(r.Context())
checks := map[string]bool{
"management_connected": mgmt,
"initial_sync_complete": sync,
"all_clients_healthy": allClientsHealthy,
}
status := "ok"
statusCode := http.StatusOK
if !c.ReadinessProbe() {
status = "fail"
statusCode = http.StatusServiceUnavailable
}
c.writeProbeResponse(w, statusCode, status, checks, clientHealth)
}
func (c *Checker) writeProbeResponse(w http.ResponseWriter, statusCode int, status string, checks map[string]bool, clients map[types.AccountID]ClientHealth) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
resp := ProbeResponse{
Status: status,
Checks: checks,
Clients: clients,
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
c.logger.Debugf("write health response: %v", err)
}
}
// ListenAndServe starts the health probe server.
func (s *Server) ListenAndServe() error {
s.logger.Infof("starting health probe server on %s", s.server.Addr)
return s.server.ListenAndServe()
}
// Serve starts the health probe server on the given listener.
func (s *Server) Serve(l net.Listener) error {
s.logger.Infof("starting health probe server on %s", l.Addr())
return s.server.Serve(l)
}
// Shutdown gracefully shuts down the health probe server.
func (s *Server) Shutdown(ctx context.Context) error {
return s.server.Shutdown(ctx)
}
// NewChecker creates a new health checker.
func NewChecker(logger *log.Logger, provider clientProvider) *Checker {
if logger == nil {
logger = log.StandardLogger()
}
return &Checker{
logger: logger,
provider: provider,
checkSem: make(chan struct{}, maxConcurrentChecks),
}
}
// NewServer creates a new health probe server.
func NewServer(addr string, checker *Checker, logger *log.Logger) *Server {
if logger == nil {
logger = log.StandardLogger()
}
return &Server{
server: &http.Server{
Addr: addr,
Handler: checker.Handler(),
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
},
logger: logger,
checker: checker,
}
}
func checkClientHealth(client *embed.Client) ClientHealth {
status, err := client.Status()
if err != nil {
return ClientHealth{
Healthy: false,
Error: err.Error(),
}
}
// Count only rel:// and rels:// relays (not stun/turn)
var relayCount, relaysConnected int
for _, relay := range status.Relays {
if !strings.HasPrefix(relay.URI, "rel://") && !strings.HasPrefix(relay.URI, "rels://") {
continue
}
relayCount++
if relay.Err == nil {
relaysConnected++
}
}
// Client is healthy if connected to management, signal, and at least one relay (if any are defined)
healthy := status.ManagementState.Connected &&
status.SignalState.Connected &&
(relayCount == 0 || relaysConnected > 0)
return ClientHealth{
Healthy: healthy,
ManagementConnected: status.ManagementState.Connected,
SignalConnected: status.SignalState.Connected,
RelaysConnected: relaysConnected,
RelaysTotal: relayCount,
}
}

View File

@@ -0,0 +1,155 @@
package health
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
)
type mockClientProvider struct {
clients map[types.AccountID]*embed.Client
}
func (m *mockClientProvider) ListClientsForStartup() map[types.AccountID]*embed.Client {
return m.clients
}
func TestChecker_LivenessProbe(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
// Liveness should always return true if we can respond.
assert.True(t, checker.LivenessProbe())
}
func TestChecker_ReadinessProbe(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
// Initially not ready (management not connected).
assert.False(t, checker.ReadinessProbe())
// After management connects, should be ready.
checker.SetManagementConnected(true)
assert.True(t, checker.ReadinessProbe())
// If management disconnects, should not be ready.
checker.SetManagementConnected(false)
assert.False(t, checker.ReadinessProbe())
}
func TestChecker_StartupProbe_NoClients(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
// Initially startup not complete.
assert.False(t, checker.StartupProbe(context.Background()))
// Just management connected is not enough.
checker.SetManagementConnected(true)
assert.False(t, checker.StartupProbe(context.Background()))
// Management + initial sync but no clients = not ready
checker.SetInitialSyncComplete()
assert.False(t, checker.StartupProbe(context.Background()))
}
func TestChecker_Handler_Liveness(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "ok", resp.Status)
}
func TestChecker_Handler_Readiness_NotReady(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "fail", resp.Status)
assert.False(t, resp.Checks["management_connected"])
}
func TestChecker_Handler_Readiness_Ready(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "ok", resp.Status)
assert.True(t, resp.Checks["management_connected"])
}
func TestChecker_Handler_Startup_NotComplete(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "fail", resp.Status)
}
func TestChecker_Handler_Full(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "ok", resp.Status)
assert.NotNil(t, resp.Checks)
// Clients may be empty map when no clients exist.
assert.Empty(t, resp.Clients)
}
func TestChecker_StartupProbe_RespectsContext(t *testing.T) {
checker := NewChecker(nil, &mockClientProvider{})
checker.SetManagementConnected(true)
checker.SetInitialSyncComplete()
// Cancelled context should return false quickly
ctx, cancel := context.WithCancel(context.Background())
cancel()
result := checker.StartupProbe(ctx)
assert.False(t, result)
}

View File

@@ -56,20 +56,18 @@ type NetBird struct {
statusNotifier statusNotifier
}
// NewNetBird creates a new NetBird transport.
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
mgmtAddr: mgmtAddr,
proxyID: proxyID,
logger: logger,
clients: make(map[types.AccountID]*clientEntry),
statusNotifier: notifier,
}
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
}
// accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{}
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
// one is created using the provided setup key. Multiple domains can share the same client.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, key, reverseProxyID string) error {
@@ -379,15 +377,6 @@ func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
return entry.client, true
}
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
AccountID types.AccountID
DomainCount int
Domains domain.List
HasClient bool
CreatedAt time.Time
}
// ListClientsForDebug returns information about all clients for debug purposes.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
n.clientsMux.RLock()
@@ -410,8 +399,33 @@ func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
return result
}
// accountIDContextKey is the context key for storing the account ID.
type accountIDContextKey struct{}
// ListClientsForStartup returns all embed.Client instances for health checks.
func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
n.clientsMux.RLock()
defer n.clientsMux.RUnlock()
result := make(map[types.AccountID]*embed.Client)
for accountID, entry := range n.clients {
if entry.client != nil {
result[accountID] = entry.client
}
}
return result
}
// NewNetBird creates a new NetBird transport.
func NewNetBird(mgmtAddr, proxyID string, logger *log.Logger, notifier statusNotifier) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
mgmtAddr: mgmtAddr,
proxyID: proxyID,
logger: logger,
clients: make(map[types.AccountID]*clientEntry),
statusNotifier: notifier,
}
}
// WithAccountID adds the account ID to the context.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {

View File

@@ -32,6 +32,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/debug"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
@@ -41,14 +42,16 @@ import (
)
type Server struct {
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
mgmtClient proto.ProxyServiceClient
proxy *proxy.ReverseProxy
netbird *roundtrip.NetBird
acme *acme.Manager
auth *auth.Middleware
http *http.Server
https *http.Server
debug *http.Server
healthServer *health.Server
healthChecker *health.Checker
// Mostly used for debugging on management.
startTime time.Time
@@ -71,6 +74,8 @@ type Server struct {
DebugEndpointEnabled bool
// DebugEndpointAddress is the address for the debug HTTP endpoint (default: ":8444").
DebugEndpointAddress string
// HealthAddress is the address for the health probe endpoint (default: "localhost:8080").
HealthAddress string
}
// NotifyStatus sends a status update to management about tunnel connectivity
@@ -221,7 +226,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
Handler: debugHandler,
}
go func() {
s.Logger.WithField("address", debugAddr).Info("starting debug endpoint")
s.Logger.Infof("starting debug endpoint on %s", debugAddr)
if err := s.debug.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("debug endpoint error: %v", err)
}
@@ -233,6 +238,30 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
}()
}
// Start health probe server on separate port for Kubernetes probes.
healthAddr := s.HealthAddress
if healthAddr == "" {
healthAddr = "localhost:8080"
}
s.healthChecker = health.NewChecker(s.Logger, s.netbird)
s.healthServer = health.NewServer(healthAddr, s.healthChecker, s.Logger)
healthListener, err := net.Listen("tcp", healthAddr)
if err != nil {
return fmt.Errorf("health probe server listen on %s: %w", healthAddr, err)
}
go func() {
if err := s.healthServer.Serve(healthListener); err != nil && !errors.Is(err, http.ErrServerClosed) {
s.Logger.Errorf("health probe server: %v", err)
}
}()
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.healthServer.Shutdown(shutdownCtx); err != nil {
s.Logger.Debugf("health probe server shutdown: %v", err)
}
}()
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
@@ -252,8 +281,15 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.ProxyServiceClient) {
b := backoff.New(0, 0)
initialSyncDone := false
for {
s.Logger.Debug("Getting mapping updates from management server")
// Mark management as disconnected while we're attempting to reconnect.
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
mappingClient, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: s.ID,
Version: s.Version,
@@ -270,8 +306,14 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
time.Sleep(backoffDuration)
continue
}
// Mark management as connected once stream is established.
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(true)
}
s.Logger.Debug("Got mapping updates client from management server")
err = s.handleMappingStream(ctx, mappingClient)
err = s.handleMappingStream(ctx, mappingClient, &initialSyncDone)
backoffDuration := b.Duration()
switch {
case errors.Is(err, context.Canceled),
@@ -294,7 +336,7 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
}
}
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient) error {
func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.ProxyService_GetMappingUpdateClient, initialSyncDone *bool) error {
for {
// Check for context completion to gracefully shutdown.
select {
@@ -338,6 +380,14 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
}
}
s.Logger.Debug("Processing mapping update completed")
// After the first mapping sync, mark initial sync complete.
// Client health is checked directly in the startup probe.
if !*initialSyncDone && s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
}
}
}
@@ -376,7 +426,7 @@ func (s *Server) updateMapping(ctx context.Context, mapping *proto.ProxyMapping)
}
if mapping.GetAuth().GetOidc() != nil {
oidc := mapping.GetAuth().GetOidc()
schemes = append(schemes, auth.NewOIDC(mgmtClient, mapping.GetId(), mapping.GetAccountId(), auth.OIDCConfig{
schemes = append(schemes, auth.NewOIDC(s.mgmtClient, mapping.GetId(), mapping.GetAccountId(), auth.OIDCConfig{
Issuer: oidc.GetIssuer(),
Audiences: oidc.GetAudiences(),
KeysLocation: oidc.GetKeysLocation(),