[management, reverse proxy] Add reverse proxy feature (#5291)

* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
Pascal Fischer
2026-02-13 19:37:43 +01:00
committed by GitHub
parent edce11b34d
commit f53155562f
225 changed files with 35513 additions and 235 deletions

View File

@@ -0,0 +1,105 @@
package accesslog
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
// gRPCClient is the single-method client surface Logger needs: one call that
// delivers a SendAccessLogRequest. Logger.log invokes it once per access log
// entry.
type gRPCClient interface {
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
}
// Logger sends access log entries to the management server via gRPC.
type Logger struct {
// client delivers each entry through SendAccessLog.
client gRPCClient
// logger receives debug traces and, when delivery fails, the entry itself.
logger *log.Logger
// trustedProxies is handed to extractSourceIP when resolving the client IP.
trustedProxies []netip.Prefix
}
// NewLogger builds a Logger that delivers access log entries through client.
// A nil logger is replaced by the logrus standard logger. trustedProxies lists
// the upstream proxy IP ranges that are trusted when the real client IP is
// derived from X-Forwarded-For headers.
func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Prefix) *Logger {
	l := &Logger{
		client:         client,
		logger:         logger,
		trustedProxies: trustedProxies,
	}
	if l.logger == nil {
		l.logger = log.StandardLogger()
	}
	return l
}
// logEntry holds the per-request values that Logger.log copies into a
// proto.AccessLog message (and into the fallback error log on failure).
type logEntry struct {
ID string // sent as LogId
AccountID string
ServiceId string
Host string
Path string
DurationMs int64 // request duration in milliseconds
Method string
ResponseCode int32
SourceIp string
AuthMechanism string
UserId string // cleared by Logger.log unless AuthMechanism is the OIDC method
AuthSuccess bool
}
// log delivers entry to the management server from a background goroutine.
//
// Sending asynchronously keeps the middleware's response latency low at the
// cost of a higher chance of losing an entry (a failed send is still written
// to the error log) and of entries arriving out of order; the timestamp
// attached to each entry lets the server re-order them.
//
// The passed ctx is not used for the send: the goroutine runs under its own
// 10 second timeout derived from context.Background().
func (l *Logger) log(ctx context.Context, entry logEntry) {
	// Capture the timestamp on the caller's goroutine so scheduling delay of
	// the sender does not skew it. This is probably unnecessary.
	stamp := timestamppb.Now()

	send := func() {
		sendCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		// A user ID is only reported for OIDC-authenticated requests.
		if entry.AuthMechanism != auth.MethodOIDC.String() {
			entry.UserId = ""
		}

		req := &proto.SendAccessLogRequest{
			Log: &proto.AccessLog{
				LogId:         entry.ID,
				AccountId:     entry.AccountID,
				Timestamp:     stamp,
				ServiceId:     entry.ServiceId,
				Host:          entry.Host,
				Path:          entry.Path,
				DurationMs:    entry.DurationMs,
				Method:        entry.Method,
				ResponseCode:  entry.ResponseCode,
				SourceIp:      entry.SourceIp,
				AuthMechanism: entry.AuthMechanism,
				UserId:        entry.UserId,
				AuthSuccess:   entry.AuthSuccess,
			},
		}

		_, err := l.client.SendAccessLog(sendCtx, req)
		if err == nil {
			return
		}

		// Delivery over gRPC failed: keep the entry in the local error log.
		fields := log.Fields{
			"service_id":     entry.ServiceId,
			"host":           entry.Host,
			"path":           entry.Path,
			"duration":       entry.DurationMs,
			"method":         entry.Method,
			"response_code":  entry.ResponseCode,
			"source_ip":      entry.SourceIp,
			"auth_mechanism": entry.AuthMechanism,
			"user_id":        entry.UserId,
			"auth_success":   entry.AuthSuccess,
			"error":          err,
		}
		l.logger.WithFields(fields).Error("Error sending access log on gRPC connection")
	}

	go send()
}

View File

@@ -0,0 +1,74 @@
package accesslog
import (
"net"
"net/http"
"strings"
"time"
"github.com/rs/xid"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/web"
)
// Middleware wraps next so that every request, except those for the proxy's
// own web assets, gets a request ID, a CapturedData value in its context, and
// an access log entry once next has returned.
func (l *Logger) Middleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		// Internal proxy assets (CSS, JS, etc.) are served without logging.
		if strings.HasPrefix(r.URL.Path, web.PathPrefix+"/") {
			next.ServeHTTP(w, r)
			return
		}

		// The request ID is created up front so error pages and log lines
		// can be correlated.
		reqID := xid.New().String()
		l.logger.Debugf("request: request_id=%s method=%s host=%s path=%s", reqID, r.Method, r.Host, r.URL.Path)

		// Wrap the writer so the status code can be read afterwards.
		recorder := &statusWriter{w: w, status: http.StatusOK}

		// The source IP must be resolved before the request is passed on,
		// because the proxy modifies forwarding headers.
		clientIP := extractSourceIP(r, l.trustedProxies)

		// Downstream handlers mutate this struct through the pointer stored
		// in the request context.
		captured := &proxy.CapturedData{RequestID: reqID}
		captured.SetClientIP(clientIP)
		ctx := proxy.WithCapturedData(r.Context(), captured)

		began := time.Now()
		next.ServeHTTP(recorder, r.WithContext(ctx))
		elapsed := time.Since(began)

		// Strip the port when present; otherwise use the Host value as is.
		host := r.Host
		if h, _, err := net.SplitHostPort(r.Host); err == nil {
			host = h
		}

		denied := recorder.status == http.StatusUnauthorized || recorder.status == http.StatusForbidden
		entry := logEntry{
			ID:            reqID,
			ServiceId:     captured.GetServiceId(),
			AccountID:     string(captured.GetAccountId()),
			Host:          host,
			Path:          r.URL.Path,
			DurationMs:    elapsed.Milliseconds(),
			Method:        r.Method,
			ResponseCode:  int32(recorder.status),
			SourceIp:      clientIP,
			AuthMechanism: captured.GetAuthMethod(),
			UserId:        captured.GetUserID(),
			AuthSuccess:   !denied,
		}

		l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
			reqID, r.Method, host, r.URL.Path, recorder.status, elapsed.Milliseconds(), clientIP, captured.GetOrigin(), captured.GetServiceId(), captured.GetAccountId())

		l.log(r.Context(), entry)
	}

	return http.HandlerFunc(handle)
}

View File

@@ -0,0 +1,16 @@
package accesslog
import (
"net/http"
"net/netip"
"github.com/netbirdio/netbird/proxy/internal/proxy"
)
// extractSourceIP returns the client IP for r. It delegates to
// proxy.ResolveClientIP, passing the connection's RemoteAddr, the request's
// X-Forwarded-For header and trustedProxies.
//
// As described by the original author: when trustedProxies is non-empty and
// the direct peer is trusted, X-Forwarded-For is walked right-to-left skipping
// trusted IPs; otherwise RemoteAddr is returned directly. That behavior lives
// in the helper and is not visible here.
func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string {
	forwardedFor := r.Header.Get("X-Forwarded-For")
	return proxy.ResolveClientIP(r.RemoteAddr, forwardedFor, trustedProxies)
}

View File

@@ -0,0 +1,26 @@
package accesslog
import (
"net/http"
)
// statusWriter wraps an http.ResponseWriter and records the status code
// passed to WriteHeader so that it can be read after the handler returns.
// The status field keeps whatever value the creator initialised it with
// until WriteHeader is called.
type statusWriter struct {
	w      http.ResponseWriter
	status int
}

// Header returns the header map of the wrapped writer.
func (w *statusWriter) Header() http.Header {
	return w.w.Header()
}

// Write forwards data to the wrapped writer unchanged.
func (w *statusWriter) Write(data []byte) (int, error) {
	return w.w.Write(data)
}

// WriteHeader records status and then forwards it to the wrapped writer.
func (w *statusWriter) WriteHeader(status int) {
	w.status = status
	w.w.WriteHeader(status)
}

// Unwrap exposes the wrapped writer so that http.ResponseController can reach
// optional capabilities of the underlying writer (flushing, hijacking,
// deadlines). Without it, wrapping a response hides those capabilities and
// streaming or connection upgrades through this writer fail.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.w
}

View File

@@ -0,0 +1,102 @@
package acme
import (
"context"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/flock"
"github.com/netbirdio/netbird/proxy/internal/k8s"
)
// certLocker provides distributed mutual exclusion for certificate operations.
// Implementations must be safe for concurrent use from multiple goroutines.
// This package provides flockLocker, k8sLeaseLocker and noopLocker.
type certLocker interface {
// Lock acquires an exclusive lock for the given domain.
// It blocks until the lock is acquired, the context is cancelled, or an
// unrecoverable error occurs. The returned function releases the lock;
// callers must call it exactly once when the critical section is complete.
Lock(ctx context.Context, domain string) (unlock func(), err error)
}
// CertLockMethod controls how ACME certificate locks are coordinated.
// newCertLocker treats the empty string like CertLockAuto and falls back to
// flock for any value it does not recognise.
type CertLockMethod string
const (
// CertLockAuto detects the environment and selects k8s-lease if running
// in a Kubernetes pod, otherwise flock.
CertLockAuto CertLockMethod = "auto"
// CertLockFlock uses advisory file locks via flock(2).
CertLockFlock CertLockMethod = "flock"
// CertLockK8sLease uses Kubernetes coordination Leases.
CertLockK8sLease CertLockMethod = "k8s-lease"
)
// newCertLocker builds the certLocker selected by method.
//
// An empty method or CertLockAuto resolves to CertLockK8sLease when
// k8s.InCluster reports true and to CertLockFlock otherwise. If the k8s lease
// locker cannot be created, a flock locker on certDir is returned instead.
// Any other method value also yields a flock locker. A nil logger is replaced
// by the logrus standard logger.
func newCertLocker(method CertLockMethod, certDir string, logger *log.Logger) certLocker {
	if logger == nil {
		logger = log.StandardLogger()
	}

	if method == "" || method == CertLockAuto {
		method = CertLockFlock
		if k8s.InCluster() {
			method = CertLockK8sLease
		}
		logger.Infof("auto-detected cert lock method: %s", method)
	}

	if method != CertLockK8sLease {
		logger.Infof("using flock cert locker in %s", certDir)
		return newFlockLocker(certDir, logger)
	}

	leaseLocker, err := newK8sLeaseLocker(logger)
	if err != nil {
		logger.Warnf("create k8s lease locker, falling back to flock: %v", err)
		return newFlockLocker(certDir, logger)
	}
	logger.Infof("using k8s lease locker in namespace %s", leaseLocker.client.Namespace())
	return leaseLocker
}
// flockLocker is a certLocker backed by per-domain lock files.
type flockLocker struct {
// certDir is the directory in which "<domain>.lock" files are created.
certDir string
// logger receives a debug message when releasing a lock fails.
logger *log.Logger
}
// newFlockLocker returns a flockLocker that keeps its lock files in certDir.
// A nil logger is replaced by the logrus standard logger.
func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
	locker := &flockLocker{certDir: certDir, logger: logger}
	if locker.logger == nil {
		locker.logger = log.StandardLogger()
	}
	return locker
}
// Lock takes an advisory file lock on "<certDir>/<domain>.lock" and returns
// the function that releases it. When flock.Lock reports no error but also no
// file (locking unsupported, i.e. non-unix), the returned release function
// does nothing.
func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
	path := filepath.Join(l.certDir, domain+".lock")

	handle, err := flock.Lock(ctx, path)
	switch {
	case err != nil:
		return nil, err
	case handle == nil:
		// nil handle means locking is not supported (non-unix).
		return func() { /* no-op: locking unsupported on this platform */ }, nil
	}

	release := func() {
		if unlockErr := flock.Unlock(handle); unlockErr != nil {
			l.logger.Debugf("release cert lock for domain %q: %v", domain, unlockErr)
		}
	}
	return release, nil
}
// noopLocker is a certLocker that performs no locking.
type noopLocker struct{}

// Lock succeeds immediately and hands back a release function that does
// nothing.
func (noopLocker) Lock(_ context.Context, _ string) (func(), error) {
	release := func() { /* no-op: locker disabled */ }
	return release, nil
}

View File

@@ -0,0 +1,197 @@
package acme
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/k8s"
)
const (
// leaseDurationSec is the Kubernetes Lease TTL. If the holder crashes without
// releasing the lock, other replicas must wait this long before taking over.
// This is intentionally generous: in the worst case two replicas may both
// issue an ACME request for the same domain, which is harmless (the CA
// deduplicates and the cache converges).
leaseDurationSec = 300
// retryBaseBackoff is the first delay Lock waits between acquisition attempts.
retryBaseBackoff = 500 * time.Millisecond
// retryMaxBackoff caps the delay, which Lock doubles after every attempt.
retryMaxBackoff = 10 * time.Second
)
// k8sLeaseLocker is a certLocker that coordinates through Lease objects
// accessed via k8s.LeaseClient.
type k8sLeaseLocker struct {
// client reads, creates and updates Lease objects.
client *k8s.LeaseClient
// identity is written as the lease holder; newK8sLeaseLocker sets it to
// the hostname.
identity string
logger *log.Logger
}
// newK8sLeaseLocker creates a lease-based locker that identifies itself by
// the machine's hostname. It fails when the lease client cannot be created or
// the hostname cannot be determined.
func newK8sLeaseLocker(logger *log.Logger) (*k8sLeaseLocker, error) {
	leaseClient, err := k8s.NewLeaseClient()
	if err != nil {
		return nil, fmt.Errorf("create k8s lease client: %w", err)
	}

	hostname, err := os.Hostname()
	if err != nil {
		return nil, fmt.Errorf("get hostname: %w", err)
	}

	locker := &k8sLeaseLocker{
		client:   leaseClient,
		identity: hostname,
		logger:   logger,
	}
	return locker, nil
}
// Lock blocks until the Lease for domain has been acquired, an acquisition
// attempt fails with an error, or ctx is done. Between attempts it waits with
// a delay that starts at retryBaseBackoff and doubles up to retryMaxBackoff.
// On success the returned function releases the lease.
func (l *k8sLeaseLocker) Lock(ctx context.Context, domain string) (func(), error) {
	leaseName := k8s.LeaseNameForDomain(domain)

	wait := retryBaseBackoff
	for {
		got, err := l.tryAcquire(ctx, leaseName, domain)
		switch {
		case err != nil:
			return nil, fmt.Errorf("acquire lease %s for %q: %w", leaseName, domain, err)
		case got:
			l.logger.Debugf("k8s lease %s acquired for domain %q", leaseName, domain)
			return l.unlockFunc(leaseName, domain), nil
		}

		l.logger.Debugf("k8s lease %s held by another replica, retrying in %s", leaseName, wait)

		pause := time.NewTimer(wait)
		select {
		case <-pause.C:
		case <-ctx.Done():
			pause.Stop()
			return nil, ctx.Err()
		}

		if wait *= 2; wait > retryMaxBackoff {
			wait = retryMaxBackoff
		}
	}
}
// tryAcquire makes a single attempt to own the Lease called name.
//
// If no Lease exists it creates one; if one exists and canTakeover allows it,
// the Lease is updated to name this locker as holder. It returns (true, nil)
// when the write succeeded, (false, nil) when the lease is held by someone
// else or the write lost a conflict (k8s.ErrConflict), and (false, err) for
// any other failure.
func (l *k8sLeaseLocker) tryAcquire(ctx context.Context, name, domain string) (bool, error) {
	current, err := l.client.Get(ctx, name)
	if err != nil {
		return false, err
	}

	stamp := k8s.MicroTime{Time: time.Now().UTC()}
	ttl := int32(leaseDurationSec)

	var writeErr error
	switch {
	case current == nil:
		fresh := &k8s.Lease{
			Metadata: k8s.LeaseMetadata{
				Name: name,
				Annotations: map[string]string{
					"netbird.io/domain": domain,
				},
			},
			Spec: k8s.LeaseSpec{
				HolderIdentity:       &l.identity,
				LeaseDurationSeconds: &ttl,
				AcquireTime:          &stamp,
				RenewTime:            &stamp,
			},
		}
		_, writeErr = l.client.Create(ctx, fresh)

	case l.canTakeover(current):
		current.Spec.HolderIdentity = &l.identity
		current.Spec.LeaseDurationSeconds = &ttl
		current.Spec.AcquireTime = &stamp
		current.Spec.RenewTime = &stamp
		_, writeErr = l.client.Update(ctx, current)

	default:
		return false, nil
	}

	// A conflict means another writer won the race; report "not acquired"
	// so the caller retries.
	if errors.Is(writeErr, k8s.ErrConflict) {
		return false, nil
	}
	if writeErr != nil {
		return false, writeErr
	}
	return true, nil
}
// canTakeover reports whether this locker may claim lease: it has no holder,
// this locker is already the holder, its renew time or duration is missing,
// or renewTime + leaseDuration lies in the past.
func (l *k8sLeaseLocker) canTakeover(lease *k8s.Lease) bool {
	holder := lease.Spec.HolderIdentity

	switch {
	case holder == nil || *holder == "":
		return true
	case *holder == l.identity:
		// We already hold it (e.g. from a previous crashed attempt).
		return true
	case lease.Spec.RenewTime == nil || lease.Spec.LeaseDurationSeconds == nil:
		return true
	}

	ttl := time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second
	expiry := lease.Spec.RenewTime.Add(ttl)
	if !time.Now().After(expiry) {
		return false
	}

	l.logger.Infof("k8s lease %s held by %q is stale (expired %s ago), taking over",
		lease.Metadata.Name, *holder, time.Since(expiry).Round(time.Second))
	return true
}
// unlockFunc returns a closure that releases the Lease called name by
// clearing its holder, acquire time and renew time. The release is
// best-effort: every failure is logged at debug level and otherwise ignored.
// The lease is left untouched when it no longer exists or when this locker is
// no longer its holder.
func (l *k8sLeaseLocker) unlockFunc(name, domain string) func() {
	return func() {
		// Use a fresh context: the parent may already be cancelled.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		// Re-GET to get current resourceVersion (ours may be stale if
		// the lock was held for a long time and something updated it).
		current, err := l.client.Get(ctx, name)
		if err != nil {
			l.logger.Debugf("release k8s lease %s for %q: get: %v", name, domain, err)
			return
		}
		if current == nil {
			return
		}

		// Only clear if we're still the holder.
		if holder := current.Spec.HolderIdentity; holder == nil || *holder != l.identity {
			// Dereference the pointer for logging; formatting the *string
			// directly would print its address instead of the holder name.
			holderName := ""
			if holder != nil {
				holderName = *holder
			}
			l.logger.Debugf("k8s lease %s for %q: holder changed to %q, skip release",
				name, domain, holderName)
			return
		}

		empty := ""
		current.Spec.HolderIdentity = &empty
		current.Spec.AcquireTime = nil
		current.Spec.RenewTime = nil

		if _, err := l.client.Update(ctx, current); err != nil {
			l.logger.Debugf("release k8s lease %s for %q: update: %v", name, domain, err)
		}
	}
}

View File

@@ -0,0 +1,65 @@
package acme
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestFlockLockerRoundTrip locks and unlocks one domain and checks that the
// lock file was created in the certificate directory.
func TestFlockLockerRoundTrip(t *testing.T) {
	certDir := t.TempDir()

	release, err := newFlockLocker(certDir, nil).Lock(context.Background(), "example.com")
	require.NoError(t, err)
	require.NotNil(t, release)

	// Lock file should exist.
	lockFile := filepath.Join(certDir, "example.com.lock")
	assert.FileExists(t, lockFile)

	release()
}
// TestNoopLocker checks that the no-op locker always hands out a usable
// release function without error.
func TestNoopLocker(t *testing.T) {
	var locker noopLocker

	release, err := locker.Lock(context.Background(), "example.com")
	require.NoError(t, err)
	require.NotNil(t, release)

	release()
}
func TestNewCertLockerDefaultsToFlock(t *testing.T) {
dir := t.TempDir()
// t.Setenv registers cleanup to restore the original value.
// os.Unsetenv is needed because the production code uses LookupEnv,
// which distinguishes "empty" from "not set".
t.Setenv("KUBERNETES_SERVICE_HOST", "")
os.Unsetenv("KUBERNETES_SERVICE_HOST")
locker := newCertLocker(CertLockAuto, dir, nil)
_, ok := locker.(*flockLocker)
assert.True(t, ok, "auto without k8s env should select flockLocker")
}
// TestNewCertLockerExplicitFlock checks that requesting flock explicitly
// yields a flockLocker.
func TestNewCertLockerExplicitFlock(t *testing.T) {
	certDir := t.TempDir()

	var isFlock bool
	switch newCertLocker(CertLockFlock, certDir, nil).(type) {
	case *flockLocker:
		isFlock = true
	}

	assert.True(t, isFlock, "explicit flock should select flockLocker")
}
// TestNewCertLockerK8sFallsBackToFlock checks that requesting the k8s lease
// locker where it cannot be created (no service account files) falls back to
// a flockLocker.
func TestNewCertLockerK8sFallsBackToFlock(t *testing.T) {
	certDir := t.TempDir()

	// k8s-lease without SA files should fall back to flock.
	var isFlock bool
	switch newCertLocker(CertLockK8sLease, certDir, nil).(type) {
	case *flockLocker:
		isFlock = true
	}

	assert.True(t, isFlock, "k8s-lease without SA should fall back to flockLocker")
}

View File

@@ -0,0 +1,336 @@
package acme
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"encoding/binary"
"fmt"
"net"
"slices"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"github.com/netbirdio/netbird/shared/management/domain"
)
// oidSCTList is the ASN.1 object identifier of the certificate extension that
// carries the SCT list (1.3.6.1.4.1.11129.2.4.2). parseSCTTimestamps uses it
// to pick that extension out of a certificate.
var oidSCTList = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 2}
// certificateNotifier is told when a certificate for a registered domain is
// available. Manager.prefetchCertificate calls it after a successful prefetch
// and only logs a warning if it returns an error.
type certificateNotifier interface {
NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error
}
// domainState tracks the outcome of the certificate prefetch for a domain.
type domainState int
const (
// domainPending is the state a domain gets in AddDomain, before its
// prefetch has finished.
domainPending domainState = iota
// domainReady is set once the prefetch returned a certificate without error.
domainReady
// domainFailed is set when the prefetch returned an error.
domainFailed
)
// domainInfo is the per-domain record kept by Manager.
type domainInfo struct {
// accountID and serviceID are passed to the certificate notifier.
accountID string
serviceID string
// state is the current prefetch state of the domain.
state domainState
// err holds the prefetch error text when state is domainFailed, and is
// empty when state is domainReady.
err string
}
// Manager wraps autocert.Manager with domain tracking and cross-replica
// coordination via a pluggable locking strategy. The locker prevents
// duplicate ACME requests when multiple replicas share a certificate cache.
type Manager struct {
*autocert.Manager
// certDir is the directory used for the certificate cache and lock files.
certDir string
// locker serialises certificate prefetches per domain.
locker certLocker
// mu guards domains and the domainInfo values stored in it.
mu sync.RWMutex
domains map[domain.Domain]*domainInfo
// certNotifier may be nil, in which case no notification is sent.
certNotifier certificateNotifier
logger *log.Logger
}
// NewManager builds an ACME certificate manager.
//
// certDir is used both as the autocert certificate cache and as the directory
// handed to the cert locker. acmeURL is the ACME directory URL. notifier may
// be nil. A nil logger is replaced by the logrus standard logger. lockMethod
// selects the cross-replica coordination strategy (see the CertLockMethod
// constants).
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
	if logger == nil {
		logger = log.StandardLogger()
	}

	m := &Manager{
		certDir:      certDir,
		locker:       newCertLocker(lockMethod, certDir, logger),
		domains:      make(map[domain.Domain]*domainInfo),
		certNotifier: notifier,
		logger:       logger,
	}

	// The embedded autocert manager is wired up after m exists because its
	// host policy is a method of m.
	acmeClient := &acme.Client{DirectoryURL: acmeURL}
	m.Manager = &autocert.Manager{
		Prompt:     autocert.AcceptTOS,
		HostPolicy: m.hostPolicy,
		Cache:      autocert.DirCache(certDir),
		Client:     acmeClient,
	}

	return m
}
// hostPolicy is the autocert host policy: it accepts host only when it is a
// registered domain. A port, if present, is stripped before the lookup.
func (mgr *Manager) hostPolicy(_ context.Context, host string) error {
	name := host
	if stripped, _, err := net.SplitHostPort(host); err == nil {
		name = stripped
	}

	mgr.mu.RLock()
	_, known := mgr.domains[domain.Domain(name)]
	mgr.mu.RUnlock()

	if known {
		return nil
	}
	return fmt.Errorf("unknown domain %q", name)
}
// AddDomain registers d in the pending state, replacing any existing record
// for it, and starts a background goroutine that prefetches its certificate.
func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) {
	record := &domainInfo{
		accountID: accountID,
		serviceID: serviceID,
		state:     domainPending,
	}

	mgr.mu.Lock()
	mgr.domains[d] = record
	mgr.mu.Unlock()

	go mgr.prefetchCertificate(d)
}
// prefetchCertificate obtains the certificate for d ahead of the first real
// TLS handshake.
//
// It first takes the distributed lock for the domain so that replicas do not
// issue duplicate ACME requests; a replica that waited for the lock then
// finds the certificate in the shared cache. If the lock cannot be taken the
// prefetch proceeds without it. The domain's state is updated to ready or
// failed, and on success the certificate notifier (if any) is informed. The
// whole operation runs under a 5 minute timeout.
func (mgr *Manager) prefetchCertificate(d domain.Domain) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
	defer cancel()

	name := d.PunycodeString()

	mgr.logger.Infof("acquiring cert lock for domain %q", name)
	lockRequested := time.Now()
	release, lockErr := mgr.locker.Lock(ctx, name)
	if lockErr == nil {
		mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockRequested))
		defer release()
	} else {
		mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", name, lockErr)
	}

	// A synthetic ClientHello drives GetCertificate as if a client had
	// connected asking for this server name.
	hello := &tls.ClientHelloInfo{
		ServerName: name,
		Conn:       &dummyConn{ctx: ctx},
	}

	fetchStart := time.Now()
	cert, err := mgr.GetCertificate(hello)
	took := time.Since(fetchStart)
	if err != nil {
		mgr.logger.Warnf("prefetch certificate for domain %q: %v", name, err)
		mgr.setDomainState(d, domainFailed, err.Error())
		return
	}

	mgr.setDomainState(d, domainReady, "")

	now := time.Now()
	if cert == nil || cert.Leaf == nil {
		mgr.logger.Infof("certificate for domain %q ready in %s", name, took.Round(time.Millisecond))
	} else {
		leaf := cert.Leaf
		mgr.logger.Infof("certificate for domain %q ready in %s: serial=%s SANs=%v notBefore=%s, notAfter=%s, now=%s",
			name, took.Round(time.Millisecond),
			leaf.SerialNumber.Text(16),
			leaf.DNSNames,
			leaf.NotBefore.UTC().Format(time.RFC3339),
			leaf.NotAfter.UTC().Format(time.RFC3339),
			now.UTC().Format(time.RFC3339),
		)
		mgr.logCertificateDetails(name, leaf, now)
	}

	mgr.mu.RLock()
	info := mgr.domains[d]
	mgr.mu.RUnlock()

	// The domain may have been removed meanwhile, and the notifier is optional.
	if info == nil || mgr.certNotifier == nil {
		return
	}
	if notifyErr := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); notifyErr != nil {
		mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, notifyErr)
	}
}
// setDomainState records state and errMsg for d. It does nothing when d is
// not (or no longer) registered.
func (mgr *Manager) setDomainState(d domain.Domain, state domainState, errMsg string) {
	mgr.mu.Lock()
	defer mgr.mu.Unlock()

	info, registered := mgr.domains[d]
	if !registered {
		return
	}
	info.state, info.err = state, errMsg
}
// logCertificateDetails warns when the certificate's NotBefore or any of its
// SCT timestamps lie after now, and logs SCT timestamps that lie in the past
// at debug level.
func (mgr *Manager) logCertificateDetails(name string, cert *x509.Certificate, now time.Time) {
	if cert.NotBefore.After(now) {
		mgr.logger.Warnf("certificate for %q NotBefore is in the future by %v", name, cert.NotBefore.Sub(now))
	}

	for idx, stamp := range mgr.parseSCTTimestamps(cert) {
		if !stamp.After(now) {
			mgr.logger.Debugf("certificate for %q SCT[%d] timestamp: %v (%v in the past)",
				name, idx, stamp.UTC(), now.Sub(stamp))
			continue
		}
		mgr.logger.Warnf("certificate for %q SCT[%d] timestamp is in the future: %v (by %v)",
			name, idx, stamp.UTC(), stamp.Sub(now))
	}
}
// parseSCTTimestamps returns the timestamps of the SCTs embedded in cert's
// SCT list extension, in the order they appear. Extensions or SCTs that are
// malformed or too short are skipped; the result is nil when nothing could
// be extracted.
//
// Parsing is confined to the declared length of the SCT list: bytes that
// follow the list inside the extension value are never interpreted as SCT
// data.
func (mgr *Manager) parseSCTTimestamps(cert *x509.Certificate) []time.Time {
	var timestamps []time.Time

	for _, ext := range cert.Extensions {
		if !ext.Id.Equal(oidSCTList) {
			continue
		}

		// The extension value is an OCTET STRING containing the SCT list.
		var sctListBytes []byte
		if _, err := asn1.Unmarshal(ext.Value, &sctListBytes); err != nil {
			mgr.logger.Debugf("failed to unmarshal SCT list outer wrapper: %v", err)
			continue
		}

		// SCT list format: 2-byte length prefix, then concatenated SCTs.
		if len(sctListBytes) < 2 {
			continue
		}
		listLen := int(binary.BigEndian.Uint16(sctListBytes[:2]))
		if len(sctListBytes)-2 < listLen {
			continue
		}
		// Restrict the view to exactly the declared list so every bounds
		// check below is relative to the list, not to the whole buffer.
		list := sctListBytes[2 : 2+listLen]

		// Each SCT is itself prefixed by a 2-byte length.
		for len(list) >= 2 {
			sctLen := int(binary.BigEndian.Uint16(list[:2]))
			list = list[2:]
			if sctLen > len(list) {
				break
			}
			sct := list[:sctLen]
			list = list[sctLen:]

			// SCT format: version (1) + log_id (32) + timestamp (8) + ...
			if len(sct) < 41 {
				continue
			}

			// Timestamp is at offset 33 (after version + log_id), 8 bytes,
			// milliseconds since epoch.
			tsMillis := binary.BigEndian.Uint64(sct[33:41])
			timestamps = append(timestamps, time.UnixMilli(int64(tsMillis)))
		}
	}

	return timestamps
}
// dummyConn is a do-nothing net.Conn whose only payload is a context. It is
// placed in the synthetic ClientHello used for certificate prefetching.
type dummyConn struct {
	ctx context.Context
}

// Read never yields data and never fails.
func (*dummyConn) Read([]byte) (int, error) { return 0, nil }

// Write discards p and reports it as fully written.
func (*dummyConn) Write(p []byte) (int, error) { return len(p), nil }

// Close does nothing.
func (*dummyConn) Close() error { return nil }

// LocalAddr returns nil; there is no real connection.
func (*dummyConn) LocalAddr() net.Addr { return nil }

// RemoteAddr returns nil; there is no real connection.
func (*dummyConn) RemoteAddr() net.Addr { return nil }

// SetDeadline is accepted and ignored.
func (*dummyConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is accepted and ignored.
func (*dummyConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is accepted and ignored.
func (*dummyConn) SetWriteDeadline(time.Time) error { return nil }
// RemoveDomain stops tracking d. Removing an unknown domain is a no-op.
func (mgr *Manager) RemoveDomain(d domain.Domain) {
	mgr.mu.Lock()
	delete(mgr.domains, d)
	mgr.mu.Unlock()
}
// PendingCerts reports how many registered domains are still in the pending
// state, i.e. whose certificate prefetch has not finished.
func (mgr *Manager) PendingCerts() int {
	mgr.mu.RLock()
	defer mgr.mu.RUnlock()

	pending := 0
	for _, info := range mgr.domains {
		if info.state != domainPending {
			continue
		}
		pending++
	}
	return pending
}
// TotalDomains reports how many domains are currently registered, in any
// state.
func (mgr *Manager) TotalDomains() int {
	mgr.mu.RLock()
	total := len(mgr.domains)
	mgr.mu.RUnlock()
	return total
}
// PendingDomains lists, sorted, the punycode names of domains whose
// certificate prefetch has not finished yet.
func (mgr *Manager) PendingDomains() []string {
	pending := mgr.domainsByState(domainPending)
	return pending
}
// ReadyDomains lists, sorted, the punycode names of domains for which a
// certificate was obtained successfully.
func (mgr *Manager) ReadyDomains() []string {
	ready := mgr.domainsByState(domainReady)
	return ready
}
// FailedDomains maps the punycode name of every domain whose certificate
// prefetch failed to the recorded error text. The map is never nil.
func (mgr *Manager) FailedDomains() map[string]string {
	failures := map[string]string{}

	mgr.mu.RLock()
	defer mgr.mu.RUnlock()

	for d, info := range mgr.domains {
		if info.state != domainFailed {
			continue
		}
		failures[d.PunycodeString()] = info.err
	}
	return failures
}
// domainsByState returns the punycode names of all domains currently in
// state, sorted ascending. The result is nil when no domain matches.
func (mgr *Manager) domainsByState(state domainState) []string {
	var names []string

	mgr.mu.RLock()
	defer mgr.mu.RUnlock()

	for d, info := range mgr.domains {
		if info.state != state {
			continue
		}
		names = append(names, d.PunycodeString())
	}

	// Map iteration order is random; sort for a stable result.
	slices.Sort(names)
	return names
}

View File

@@ -0,0 +1,102 @@
package acme
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestHostPolicy checks that hostPolicy accepts a registered domain, with or
// without a port, and rejects everything else with an "unknown domain" error.
func TestHostPolicy(t *testing.T) {
	mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
	mgr.AddDomain("example.com", "acc1", "rp1")

	// Wait for the background prefetch goroutine to finish so the temp dir
	// can be cleaned up without a race.
	t.Cleanup(func() {
		drained := func() bool { return mgr.PendingCerts() == 0 }
		assert.Eventually(t, drained, 30*time.Second, 50*time.Millisecond)
	})

	cases := []struct {
		name    string
		host    string
		wantErr bool
	}{
		{name: "exact domain match", host: "example.com"},
		{name: "domain with port", host: "example.com:443"},
		{name: "unknown domain", host: "unknown.com", wantErr: true},
		{name: "unknown domain with port", host: "unknown.com:443", wantErr: true},
		{name: "empty host", host: "", wantErr: true},
		{name: "port only", host: ":443", wantErr: true},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			err := mgr.hostPolicy(context.Background(), tt.host)
			if !tt.wantErr {
				assert.NoError(t, err)
				return
			}
			require.Error(t, err)
			assert.Contains(t, err.Error(), "unknown domain")
		})
	}
}
// TestDomainStates follows two domains from registration to the failed state
// and checks the manager's counters and listings along the way.
func TestDomainStates(t *testing.T) {
	mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")

	// A fresh manager tracks nothing.
	assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
	assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")
	assert.Empty(t, mgr.PendingDomains())
	assert.Empty(t, mgr.ReadyDomains())
	assert.Empty(t, mgr.FailedDomains())

	// AddDomain starts as pending, then the prefetch goroutine will fail
	// (no real ACME server) and transition to failed.
	for _, name := range []string{"a.example.com", "b.example.com"} {
		mgr.AddDomain(domain.Domain(name), "acc1", "rp1")
	}
	assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered")

	// Pending domains should eventually drain after prefetch goroutines finish.
	drained := func() bool { return mgr.PendingCerts() == 0 }
	assert.Eventually(t, drained, 30*time.Second, 100*time.Millisecond, "pending certs should return to zero after prefetch completes")

	assert.Empty(t, mgr.PendingDomains())
	assert.Equal(t, 2, mgr.TotalDomains(), "total domains unchanged")

	// With a fake ACME URL, both should have failed.
	failed := mgr.FailedDomains()
	assert.Len(t, failed, 2, "both domains should have failed")
	assert.Contains(t, failed, "a.example.com")
	assert.Contains(t, failed, "b.example.com")
	assert.Empty(t, mgr.ReadyDomains())
}

View File

@@ -0,0 +1,18 @@
<!doctype html>
{{/*
  Renders one login control per entry of .Methods, keyed by method name:
  "pin" and "password" each get a form whose input is named after the entry's
  value, "oidc" gets a link to the entry's value.

  The PIN and password inputs use type="password" so the secret is masked
  while typing, and attribute values are quoted so a value containing spaces
  cannot break out of the attribute.

  NOTE(review): the forms declare no method, so browsers submit them via GET
  and the secret ends up in the URL. Confirm how the server reads these
  fields before switching to method="post".
*/}}
{{ range $method, $value := .Methods }}
{{ if eq $method "pin" }}
<form>
    <label for="{{ $value }}">PIN:</label>
    <input type="password" inputmode="numeric" name="{{ $value }}" id="{{ $value }}" />
    <button type="submit">Submit</button>
</form>
{{ else if eq $method "password" }}
<form>
    <label for="{{ $value }}">Password:</label>
    <input type="password" name="{{ $value }}" id="{{ $value }}" />
    <button type="submit">Submit</button>
</form>
{{ else if eq $method "oidc" }}
<a href="{{ $value }}">Click here to log in with SSO</a>
{{ end }}
{{ end }}

View File

@@ -0,0 +1,364 @@
package auth
import (
"context"
"crypto/ed25519"
"encoding/base64"
"fmt"
"net"
"net/http"
"net/url"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/proxy/web"
"github.com/netbirdio/netbird/shared/management/proto"
)
// authenticator is the client-side call that submits an AuthenticateRequest
// and returns the AuthenticateResponse.
// NOTE(review): presumably satisfied by the generated management gRPC client
// and consumed by the Scheme implementations — confirm; no user is visible in
// this part of the file.
type authenticator interface {
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
}
// SessionValidator validates session tokens and checks user access permissions.
// It is optional for Middleware: NewMiddleware accepts nil.
type SessionValidator interface {
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
}
// Scheme defines an authentication mechanism for a domain.
type Scheme interface {
// Type reports which auth.Method this scheme implements.
Type() auth.Method
// Authenticate checks the request and determines whether it represents
// an authenticated user. An empty token indicates an unauthenticated
// request; optionally, promptData may be returned for the login UI.
// An error indicates an infrastructure failure (e.g. gRPC unavailable).
Authenticate(*http.Request) (token string, promptData string, err error)
}
// DomainConfig is the authentication configuration of one protected domain.
type DomainConfig struct {
// Schemes are the authentication mechanisms applied to the domain. A
// domain without schemes is passed through unauthenticated by Protect.
Schemes []Scheme
// SessionPublicKey is the key session cookie JWTs are validated against.
SessionPublicKey ed25519.PublicKey
// SessionExpiration is presumably the lifetime of issued sessions —
// TODO confirm; it is not read in the visible part of this file.
SessionExpiration time.Duration
// AccountID and ServiceID are copied into the request's captured data
// for access logging.
AccountID string
ServiceID string
}
// validationResult carries the outcome of a session validation.
// NOTE(review): field meanings are inferred from their names; the code that
// fills this struct is outside the visible part of the file — confirm.
type validationResult struct {
UserID string
Valid bool
DeniedReason string
}
// Middleware enforces per-domain authentication in front of an http.Handler
// (see Protect).
type Middleware struct {
// domainsMux guards domains.
domainsMux sync.RWMutex
// domains maps a host name (without port) to its configuration.
domains map[string]DomainConfig
logger *log.Logger
// sessionValidator is optional and may be nil (see NewMiddleware).
sessionValidator SessionValidator
}
// NewMiddleware builds an authentication middleware with no domains
// configured. A nil logger is replaced by the logrus standard logger.
// sessionValidator is optional; if nil, OIDC session tokens will be validated
// locally without group access checks.
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
	mw := &Middleware{
		domains:          make(map[string]DomainConfig),
		logger:           logger,
		sessionValidator: sessionValidator,
	}
	if mw.logger == nil {
		mw.logger = log.StandardLogger()
	}
	return mw
}
// Protect wraps next with per-domain authentication.
//
// The request's Host (without port) is looked up in the middleware's domain
// table. Requests for hosts that are not configured, or that are configured
// without any authentication scheme, go straight to next. For protected hosts
// the account and service IDs are recorded for access logging, and then, in
// order: an OAuth callback error is rendered if present; a valid session
// cookie forwards the request to next; otherwise the domain's schemes are
// asked to authenticate the request.
func (mw *Middleware) Protect(next http.Handler) http.Handler {
	handler := func(w http.ResponseWriter, r *http.Request) {
		host := r.Host
		if stripped, _, err := net.SplitHostPort(r.Host); err == nil {
			host = stripped
		}

		config, exists := mw.getDomainConfig(host)
		mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)

		// Domains that are not configured here or have no authentication
		// schemes applied should simply pass through.
		protected := exists && len(config.Schemes) > 0
		if !protected {
			next.ServeHTTP(w, r)
			return
		}

		// Set account and service IDs in captured data for access logging.
		setCapturedIDs(r, config)

		// Each step reports whether it fully handled the request; the first
		// one that did ends processing.
		switch {
		case mw.handleOAuthCallbackError(w, r):
		case mw.forwardWithSessionCookie(w, r, host, config, next):
		default:
			mw.authenticateWithSchemes(w, r, host, config)
		}
	}

	return http.HandlerFunc(handler)
}
// getDomainConfig returns the configuration registered for host and whether
// one exists, reading the domain table under the read lock.
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
	mw.domainsMux.RLock()
	cfg, found := mw.domains[host]
	mw.domainsMux.RUnlock()
	return cfg, found
}
// setCapturedIDs copies the domain's account and service IDs onto the
// request's captured data, if the request context carries any.
func setCapturedIDs(r *http.Request, config DomainConfig) {
	captured := proxy.CapturedDataFromContext(r.Context())
	if captured == nil {
		return
	}
	captured.SetAccountId(types.AccountID(config.AccountID))
	captured.SetServiceId(config.ServiceID)
}
// handleOAuthCallbackError renders the access denied page when the request
// carries a non-empty "error" query parameter, as sent back by an OAuth
// callback. The "error_description" parameter, when present, is shown as the
// message. It reports whether a response was written.
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
	query := r.URL.Query()
	if query.Get("error") == "" {
		return false
	}

	requestID := ""
	if captured := proxy.CapturedDataFromContext(r.Context()); captured != nil {
		captured.SetOrigin(proxy.OriginAuth)
		captured.SetAuthMethod(auth.MethodOIDC.String())
		requestID = captured.GetRequestID()
	}

	description := query.Get("error_description")
	if description == "" {
		description = "An error occurred during authentication"
	}

	web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", description, requestID)
	return true
}
// forwardWithSessionCookie forwards the request to next when it carries a
// session cookie whose JWT validates for host against the domain's public key.
// The user ID and auth method from the token are recorded on the captured
// data first. It reports whether the request was forwarded; a missing or
// invalid cookie yields false without writing a response.
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
	sessionCookie, cookieErr := r.Cookie(auth.SessionCookieName)
	if cookieErr != nil {
		return false
	}

	userID, method, jwtErr := auth.ValidateSessionJWT(sessionCookie.Value, host, config.SessionPublicKey)
	if jwtErr != nil {
		return false
	}

	captured := proxy.CapturedDataFromContext(r.Context())
	if captured != nil {
		captured.SetUserID(userID)
		captured.SetAuthMethod(method)
	}

	next.ServeHTTP(w, r)
	return true
}
// authenticateWithSchemes runs the domain's authentication schemes in order.
//
// A scheme error is treated as an infrastructure failure and answered with
// 502. The first scheme that yields a token hands over to
// handleAuthenticatedToken. If no scheme yields a token, the login page is
// rendered with 401 and the prompt data collected from every scheme.
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
	prompts := make(map[string]string)
	attemptedMethod := ""

	for _, scheme := range config.Schemes {
		token, promptData, err := scheme.Authenticate(r)
		switch {
		case err != nil:
			mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
			if captured := proxy.CapturedDataFromContext(r.Context()); captured != nil {
				captured.SetOrigin(proxy.OriginAuth)
			}
			http.Error(w, "authentication service unavailable", http.StatusBadGateway)
			return
		case token != "":
			mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
			return
		}

		// No token: remember the method if credentials were submitted but
		// rejected, so the failed attempt can be attributed in the access log.
		if wasCredentialSubmitted(r, scheme.Type()) {
			attemptedMethod = scheme.Type().String()
		}
		prompts[scheme.Type().String()] = promptData
	}

	if captured := proxy.CapturedDataFromContext(r.Context()); captured != nil {
		captured.SetOrigin(proxy.OriginAuth)
		if attemptedMethod != "" {
			captured.SetAuthMethod(attemptedMethod)
		}
	}
	web.ServeHTTP(w, r, map[string]any{"methods": prompts}, http.StatusUnauthorized)
}
// handleAuthenticatedToken validates the token produced by scheme and writes
// the response:
//   - a validation error is answered with 400;
//   - a token that is well-formed but denied renders the access denied page
//     with 403;
//   - a valid token is stored in the session cookie and the client is
//     redirected (303) to the original URL with the session_token query
//     parameter removed.
//
// In every case the origin, auth method and (when known) user ID are recorded
// on the request's captured data.
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
	result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
	if err != nil {
		if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
			cd.SetOrigin(proxy.OriginAuth)
			cd.SetAuthMethod(scheme.Type().String())
		}
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	if !result.Valid {
		var requestID string
		if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
			cd.SetOrigin(proxy.OriginAuth)
			cd.SetUserID(result.UserID)
			cd.SetAuthMethod(scheme.Type().String())
			requestID = cd.GetRequestID()
		}
		web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
		return
	}

	expiration := config.SessionExpiration
	if expiration == 0 {
		expiration = auth.DefaultSessionExpiry
	}

	http.SetCookie(w, &http.Cookie{
		Name:  auth.SessionCookieName,
		Value: token,
		// Without an explicit Path the browser scopes the cookie to the
		// directory of the URL the login happened on, so the session would
		// not be sent for other paths of the same domain.
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteLaxMode,
		MaxAge:   int(expiration.Seconds()),
	})

	// Redirect instead of forwarding the auth POST to the backend.
	// The browser will follow with a GET carrying the new session cookie.
	if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
		cd.SetOrigin(proxy.OriginAuth)
		cd.SetUserID(result.UserID)
		cd.SetAuthMethod(scheme.Type().String())
	}
	redirectURL := stripSessionTokenParam(r.URL)
	http.Redirect(w, r, redirectURL, http.StatusSeeOther)
}
// wasCredentialSubmitted reports whether the request carries a non-empty
// credential for the given auth method: the "pin" or "password" form value,
// or the "session_token" query parameter for OIDC. Any other method yields
// false.
func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
	var submitted string
	switch method {
	case auth.MethodPIN:
		submitted = r.FormValue("pin")
	case auth.MethodPassword:
		submitted = r.FormValue("password")
	case auth.MethodOIDC:
		submitted = r.URL.Query().Get("session_token")
	}
	return submitted != ""
}
// AddDomain registers (or replaces) the authentication configuration for
// domain.
//
// With no schemes the domain is stored with only its account and service IDs
// and the key and expiration are ignored. With schemes, publicKeyB64 must be a
// standard-base64 encoded Ed25519 public key; it is used to verify session
// JWTs. An undecodable or wrongly sized key returns an error and leaves any
// existing configuration untouched.
//
// Callers must not serve the domain if this returns an error, to avoid
// exposing an unauthenticated service.
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error {
	config := DomainConfig{
		AccountID: accountID,
		ServiceID: serviceID,
	}

	if len(schemes) > 0 {
		pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
		if err != nil {
			return fmt.Errorf("decode session public key for domain %s: %w", domain, err)
		}
		if got := len(pubKeyBytes); got != ed25519.PublicKeySize {
			return fmt.Errorf("invalid session public key size for domain %s: got %d, want %d", domain, got, ed25519.PublicKeySize)
		}
		config.Schemes = schemes
		config.SessionPublicKey = pubKeyBytes
		config.SessionExpiration = expiration
	}

	mw.domainsMux.Lock()
	defer mw.domainsMux.Unlock()
	mw.domains[domain] = config
	return nil
}
// RemoveDomain deletes the configuration registered for domain, if any.
// Afterwards requests for that host are passed through by Protect.
func (mw *Middleware) RemoveDomain(domain string) {
	mw.domainsMux.Lock()
	delete(mw.domains, domain)
	mw.domainsMux.Unlock()
}
// validateSessionToken checks a session token for host.
//
// OIDC tokens are sent to the session validator, when one is configured, so
// that it can decide on access; a denied session is returned as a result with
// Valid set to false rather than as an error. Every other case (PIN, password,
// or no validator) verifies the JWT locally against publicKey.
//
// An error is returned when the validator call fails or the JWT is invalid.
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
	// For non-OIDC methods or when no validator is configured, validate JWT locally.
	if method != auth.MethodOIDC || mw.sessionValidator == nil {
		userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
		if err != nil {
			return nil, err
		}
		return &validationResult{UserID: userID, Valid: true}, nil
	}

	request := &proto.ValidateSessionRequest{
		Domain:       host,
		SessionToken: token,
	}
	resp, err := mw.sessionValidator.ValidateSession(ctx, request)
	if err != nil {
		// The underlying error is logged, not returned, so the caller only
		// sees a generic message.
		mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
		return nil, fmt.Errorf("session validation failed")
	}

	if resp.Valid {
		return &validationResult{UserID: resp.UserId, Valid: true}, nil
	}

	mw.logger.WithFields(log.Fields{
		"domain":        host,
		"denied_reason": resp.DeniedReason,
		"user_id":       resp.UserId,
	}).Debug("Session validation denied")
	return &validationResult{
		UserID:       resp.UserId,
		Valid:        false,
		DeniedReason: resp.DeniedReason,
	}, nil
}
// stripSessionTokenParam returns the request URI of u with the session_token
// query parameter removed so it doesn't linger in the browser's address bar or
// history. When the parameter is absent the query string is left untouched;
// when it is present the remaining parameters are re-encoded.
//
// The result is meant to be used as a same-origin redirect target, so a run of
// leading slashes is collapsed to a single one: a Location such as
// "//other.host/path" is a protocol-relative URL and would send the browser to
// another host.
func stripSessionTokenParam(u *url.URL) string {
	clean := *u
	if q := clean.Query(); q.Has("session_token") {
		q.Del("session_token")
		clean.RawQuery = q.Encode()
	}

	uri := clean.RequestURI()
	for len(uri) > 1 && uri[0] == '/' && uri[1] == '/' {
		uri = uri[1:]
	}
	return uri
}

View File

@@ -0,0 +1,660 @@
package auth
import (
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
)
// generateTestKeyPair returns a freshly generated session key pair and fails
// the test immediately if generation returns an error.
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
	t.Helper()

	keyPair, genErr := sessionkey.GenerateKeyPair()
	require.NoError(t, genErr)
	return keyPair
}
// stubScheme is a minimal Scheme implementation for testing.
// When authFn is set it decides the outcome of Authenticate; otherwise the
// fixed token and promptID are returned with a nil error.
type stubScheme struct {
	method   auth.Method
	token    string
	promptID string
	authFn   func(*http.Request) (string, string, error)
}

// Type returns the auth method the stub was configured with.
func (s *stubScheme) Type() auth.Method {
	return s.method
}

// Authenticate delegates to authFn when present and falls back to the
// canned token and prompt ID otherwise.
func (s *stubScheme) Authenticate(r *http.Request) (string, string, error) {
	if s.authFn == nil {
		return s.token, s.promptID, nil
	}
	return s.authFn(r)
}
// newPassthroughHandler returns a stand-in backend that answers every request
// with 200 and the body "backend".
func newPassthroughHandler() http.Handler {
	serve := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("backend"))
	}
	return http.HandlerFunc(serve)
}
// TestAddDomain_ValidKey verifies that a domain registered with a scheme and a
// valid key is stored with its schemes, decoded key and expiration.
func TestAddDomain_ValidKey(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}

	addErr := middleware.AddDomain(domain, []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", "")
	require.NoError(t, addErr)

	middleware.domainsMux.RLock()
	stored, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()

	assert.True(t, found, "domain should be registered")
	assert.Len(t, stored.Schemes, 1)
	assert.Equal(t, ed25519.PublicKeySize, len(stored.SessionPublicKey))
	assert.Equal(t, time.Hour, stored.SessionExpiration)
}
// TestAddDomain_EmptyKey verifies that registering schemes with an empty key
// fails with a key-size error and leaves the domain unregistered.
func TestAddDomain_EmptyKey(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}

	addErr := middleware.AddDomain(domain, []Scheme{pinScheme}, "", time.Hour, "", "")
	require.Error(t, addErr)
	assert.Contains(t, addErr.Error(), "invalid session public key size")

	middleware.domainsMux.RLock()
	_, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()
	assert.False(t, found, "domain must not be registered with an empty session key")
}
// TestAddDomain_InvalidBase64 verifies that a key that is not valid base64 is
// rejected with a decode error and the domain stays unregistered.
func TestAddDomain_InvalidBase64(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}

	addErr := middleware.AddDomain(domain, []Scheme{pinScheme}, "not-valid-base64!!!", time.Hour, "", "")
	require.Error(t, addErr)
	assert.Contains(t, addErr.Error(), "decode session public key")

	middleware.domainsMux.RLock()
	_, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()
	assert.False(t, found, "domain must not be registered with invalid base64 key")
}
// TestAddDomain_WrongKeySize verifies that a decodable key of the wrong length
// is rejected with a key-size error and the domain stays unregistered.
func TestAddDomain_WrongKeySize(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	undersizedKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))

	addErr := middleware.AddDomain(domain, []Scheme{pinScheme}, undersizedKey, time.Hour, "", "")
	require.Error(t, addErr)
	assert.Contains(t, addErr.Error(), "invalid session public key size")

	middleware.domainsMux.RLock()
	_, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()
	assert.False(t, found, "domain must not be registered with a wrong-size key")
}
// TestAddDomain_NoSchemes_NoKeyRequired verifies that a domain without auth
// schemes can be registered with an empty key.
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)

	addErr := middleware.AddDomain(domain, nil, "", time.Hour, "", "")
	require.NoError(t, addErr, "domains with no auth schemes should not require a key")

	middleware.domainsMux.RLock()
	_, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()
	assert.True(t, found)
}
// TestAddDomain_OverwritesPreviousConfig verifies that registering a domain a
// second time replaces the stored key and expiration.
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	firstKeys := generateTestKeyPair(t)
	secondKeys := generateTestKeyPair(t)
	schemes := []Scheme{&stubScheme{method: auth.MethodPIN, promptID: "pin"}}

	require.NoError(t, middleware.AddDomain(domain, schemes, firstKeys.PublicKey, time.Hour, "", ""))
	require.NoError(t, middleware.AddDomain(domain, schemes, secondKeys.PublicKey, 2*time.Hour, "", ""))

	middleware.domainsMux.RLock()
	stored := middleware.domains[domain]
	middleware.domainsMux.RUnlock()

	// The decode error is ignored: AddDomain already decoded this same key
	// successfully above.
	wantKey, _ := base64.StdEncoding.DecodeString(secondKeys.PublicKey)
	assert.Equal(t, ed25519.PublicKey(wantKey), stored.SessionPublicKey, "should use the latest key")
	assert.Equal(t, 2*time.Hour, stored.SessionExpiration)
}
// TestRemoveDomain verifies that a registered domain is gone from the domain
// table after RemoveDomain.
func TestRemoveDomain(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain(domain, []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	middleware.RemoveDomain(domain)

	middleware.domainsMux.RLock()
	_, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()
	assert.False(t, found)
}
// TestProtect_UnknownDomainPassesThrough verifies that a request for a host
// that was never registered reaches the backend untouched.
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	protected := middleware.Protect(newPassthroughHandler())

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, "backend", recorder.Body.String())
}
// TestProtect_DomainWithNoSchemesPassesThrough verifies that a registered
// domain without auth schemes is served by the backend without authentication.
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	require.NoError(t, middleware.AddDomain("example.com", nil, "", time.Hour, "", ""))
	protected := middleware.Protect(newPassthroughHandler())

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, "backend", recorder.Body.String())
}
// TestProtect_UnauthenticatedRequestIsBlocked verifies that a request without
// credentials for a protected domain never reaches the backend.
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "unauthenticated request should not reach backend")
}
// TestProtect_HostWithPortIsMatched verifies that a Host header carrying a
// port is still matched against the protected domain.
func TestProtect_HostWithPortIsMatched(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "http://example.com:8443/", nil)
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "host with port should still match the protected domain")
}
// TestProtect_ValidSessionCookiePassesThrough verifies that a request with a
// valid session cookie reaches the backend and that the user ID and auth
// method from the token are visible on the captured data there.
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	sessionToken, signErr := sessionkey.SignToken(keys.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
	require.NoError(t, signErr)

	backend := func(w http.ResponseWriter, r *http.Request) {
		captured := proxy.CapturedDataFromContext(r.Context())
		require.NotNil(t, captured)
		assert.Equal(t, "test-user", captured.GetUserID())
		assert.Equal(t, "pin", captured.GetAuthMethod())
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("authenticated"))
	}
	protected := middleware.Protect(http.HandlerFunc(backend))

	capturedData := &proxy.CapturedData{}
	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request = request.WithContext(proxy.WithCapturedData(request.Context(), capturedData))
	request.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: sessionToken})

	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, "authenticated", recorder.Body.String())
}
// TestProtect_ExpiredSessionCookieIsRejected verifies that a session cookie
// whose token has already expired does not let the request through.
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	// Sign a token that expired 1 second ago.
	expiredToken, signErr := sessionkey.SignToken(keys.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
	require.NoError(t, signErr)

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: expiredToken})
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "expired session should not reach the backend")
}
// TestProtect_WrongDomainCookieIsRejected verifies that a token signed for a
// different domain is not accepted as a session for example.com.
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	// Token signed for a different domain audience.
	foreignToken, signErr := sessionkey.SignToken(keys.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
	require.NoError(t, signErr)

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: foreignToken})
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "cookie for wrong domain should be rejected")
}
// TestProtect_WrongKeyCookieIsRejected verifies that a token signed with a key
// other than the one registered for the domain is rejected.
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	registeredKeys := generateTestKeyPair(t)
	otherKeys := generateTestKeyPair(t)
	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, registeredKeys.PublicKey, time.Hour, "", ""))

	// Token signed with a different private key.
	forgedToken, signErr := sessionkey.SignToken(otherKeys.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
	require.NoError(t, signErr)

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: forgedToken})
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "cookie signed by wrong key should be rejected")
}
// TestProtect_SchemeAuthRedirectsWithCookie verifies that a successful PIN
// submission is answered with a 303 redirect to the original URI and a
// hardened session cookie, without the backend being called.
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)

	sessionToken, signErr := sessionkey.SignToken(keys.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
	require.NoError(t, signErr)

	pinScheme := &stubScheme{
		method: auth.MethodPIN,
		authFn: func(r *http.Request) (string, string, error) {
			if r.FormValue("pin") != "111111" {
				return "", "pin", nil
			}
			return sessionToken, "", nil
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{pinScheme}, keys.PublicKey, time.Hour, "", ""))

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	// Submit the PIN via form POST.
	formBody := url.Values{"pin": {"111111"}}.Encode()
	request := httptest.NewRequest(http.MethodPost, "http://example.com/somepath", strings.NewReader(formBody))
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "backend should not be called during auth, only a redirect should be returned")
	assert.Equal(t, http.StatusSeeOther, recorder.Code)
	assert.Equal(t, "/somepath", recorder.Header().Get("Location"), "redirect should point to the original request URI")

	var sessionCookie *http.Cookie
	for _, candidate := range recorder.Result().Cookies() {
		if candidate.Name == auth.SessionCookieName {
			sessionCookie = candidate
			break
		}
	}
	require.NotNil(t, sessionCookie, "session cookie should be set after successful auth")
	assert.True(t, sessionCookie.HttpOnly)
	assert.True(t, sessionCookie.Secure)
	assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
}
// TestProtect_FailedAuthDoesNotSetCookie verifies that no session cookie is
// issued when the scheme does not produce a token.
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	failingScheme := &stubScheme{
		method: auth.MethodPIN,
		authFn: func(_ *http.Request) (string, string, error) {
			return "", "pin", nil
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{failingScheme}, keys.PublicKey, time.Hour, "", ""))

	protected := middleware.Protect(newPassthroughHandler())
	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	protected.ServeHTTP(recorder, request)

	for _, issued := range recorder.Result().Cookies() {
		assert.NotEqual(t, auth.SessionCookieName, issued.Name, "no session cookie should be set on failed auth")
	}
}
// TestProtect_MultipleSchemes verifies that when the first scheme yields no
// token the next scheme is tried, and that its success produces the redirect.
func TestProtect_MultipleSchemes(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)

	sessionToken, signErr := sessionkey.SignToken(keys.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
	require.NoError(t, signErr)

	// First scheme (PIN) always fails, second scheme (password) succeeds.
	schemes := []Scheme{
		&stubScheme{
			method: auth.MethodPIN,
			authFn: func(_ *http.Request) (string, string, error) {
				return "", "pin", nil
			},
		},
		&stubScheme{
			method: auth.MethodPassword,
			authFn: func(r *http.Request) (string, string, error) {
				if r.FormValue("password") != "secret" {
					return "", "password", nil
				}
				return sessionToken, "", nil
			},
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", schemes, keys.PublicKey, time.Hour, "", ""))

	reachedBackend := false
	protected := middleware.Protect(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reachedBackend = true
		w.WriteHeader(http.StatusOK)
	}))

	formBody := url.Values{"password": {"secret"}}.Encode()
	request := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(formBody))
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.False(t, reachedBackend, "backend should not be called during auth")
	assert.Equal(t, http.StatusSeeOther, recorder.Code)
}
// TestProtect_InvalidTokenFromSchemeReturns400 verifies that a token returned
// by a scheme but failing validation is answered with 400.
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)

	// Return a garbage token that won't validate.
	garbageScheme := &stubScheme{
		method: auth.MethodPIN,
		authFn: func(_ *http.Request) (string, string, error) {
			return "invalid-jwt-token", "", nil
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{garbageScheme}, keys.PublicKey, time.Hour, "", ""))

	protected := middleware.Protect(newPassthroughHandler())
	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusBadRequest, recorder.Code)
}
// TestAddDomain_RandomBytes32NotEd25519 verifies that AddDomain only checks
// the length of the decoded key: any 32 random bytes are accepted.
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)

	// 32 random bytes, base64 encoded. They were not generated as an ed25519
	// key, but they have the length of one, and AddDomain validates nothing
	// beyond that length, so registration should succeed.
	raw := make([]byte, ed25519.PublicKeySize)
	_, readErr := rand.Read(raw)
	require.NoError(t, readErr)
	encodedKey := base64.StdEncoding.EncodeToString(raw)

	pinScheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
	addErr := middleware.AddDomain("example.com", []Scheme{pinScheme}, encodedKey, time.Hour, "", "")
	require.NoError(t, addErr, "any 32-byte key should be accepted at registration time")
}
// TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig verifies that a failed
// re-registration leaves the previously stored configuration in place.
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
	const domain = "example.com"

	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)
	schemes := []Scheme{&stubScheme{method: auth.MethodPIN, promptID: "pin"}}
	require.NoError(t, middleware.AddDomain(domain, schemes, keys.PublicKey, time.Hour, "", ""))

	// Attempt to overwrite with an invalid key.
	overwriteErr := middleware.AddDomain(domain, schemes, "bad", time.Hour, "", "")
	require.Error(t, overwriteErr)

	// The original valid config should still be intact.
	middleware.domainsMux.RLock()
	stored, found := middleware.domains[domain]
	middleware.domainsMux.RUnlock()

	assert.True(t, found, "original config should still exist")
	assert.Len(t, stored.Schemes, 1)
	assert.Equal(t, time.Hour, stored.SessionExpiration)
}
// TestProtect_FailedPinAuthCapturesAuthMethod verifies that a rejected PIN
// submission is answered with 401 and records "pin" as the auth method.
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)

	// Scheme that always fails authentication (returns empty token)
	failingScheme := &stubScheme{
		method: auth.MethodPIN,
		authFn: func(_ *http.Request) (string, string, error) {
			return "", "pin", nil
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{failingScheme}, keys.PublicKey, time.Hour, "", ""))
	protected := middleware.Protect(newPassthroughHandler())

	// Submit wrong PIN - should capture auth method
	capturedData := &proxy.CapturedData{}
	formBody := url.Values{"pin": {"wrong-pin"}}.Encode()
	request := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(formBody))
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	request = request.WithContext(proxy.WithCapturedData(request.Context(), capturedData))

	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusUnauthorized, recorder.Code)
	assert.Equal(t, "pin", capturedData.GetAuthMethod(), "Auth method should be captured for failed PIN auth")
}
// TestProtect_FailedPasswordAuthCapturesAuthMethod verifies that a rejected
// password submission is answered with 401 and records "password" as the auth
// method.
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)

	failingScheme := &stubScheme{
		method: auth.MethodPassword,
		authFn: func(_ *http.Request) (string, string, error) {
			return "", "password", nil
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{failingScheme}, keys.PublicKey, time.Hour, "", ""))
	protected := middleware.Protect(newPassthroughHandler())

	// Submit wrong password - should capture auth method
	capturedData := &proxy.CapturedData{}
	formBody := url.Values{"password": {"wrong-password"}}.Encode()
	request := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(formBody))
	request.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	request = request.WithContext(proxy.WithCapturedData(request.Context(), capturedData))

	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusUnauthorized, recorder.Code)
	assert.Equal(t, "password", capturedData.GetAuthMethod(), "Auth method should be captured for failed password auth")
}
// TestProtect_NoCredentialsDoesNotCaptureAuthMethod verifies that a plain
// request without credentials is answered with 401 and leaves the captured
// auth method empty.
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
	middleware := NewMiddleware(log.StandardLogger(), nil)
	keys := generateTestKeyPair(t)

	failingScheme := &stubScheme{
		method: auth.MethodPIN,
		authFn: func(_ *http.Request) (string, string, error) {
			return "", "pin", nil
		},
	}
	require.NoError(t, middleware.AddDomain("example.com", []Scheme{failingScheme}, keys.PublicKey, time.Hour, "", ""))
	protected := middleware.Protect(newPassthroughHandler())

	// No credentials submitted - should not capture auth method
	capturedData := &proxy.CapturedData{}
	request := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
	request = request.WithContext(proxy.WithCapturedData(request.Context(), capturedData))

	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusUnauthorized, recorder.Code)
	assert.Empty(t, capturedData.GetAuthMethod(), "Auth method should not be captured when no credentials submitted")
}
// TestWasCredentialSubmitted checks credential detection for each auth method,
// with and without the credential present, using form-encoded POST requests.
func TestWasCredentialSubmitted(t *testing.T) {
	type testCase struct {
		name     string
		method   auth.Method
		formData url.Values
		query    url.Values
		expected bool
	}

	cases := []testCase{
		{name: "PIN submitted", method: auth.MethodPIN, formData: url.Values{"pin": {"123456"}}, expected: true},
		{name: "PIN not submitted", method: auth.MethodPIN, formData: url.Values{}, expected: false},
		{name: "Password submitted", method: auth.MethodPassword, formData: url.Values{"password": {"secret"}}, expected: true},
		{name: "Password not submitted", method: auth.MethodPassword, formData: url.Values{}, expected: false},
		{name: "OIDC token in query", method: auth.MethodOIDC, query: url.Values{"session_token": {"abc123"}}, expected: true},
		{name: "OIDC token not in query", method: auth.MethodOIDC, query: url.Values{}, expected: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			target := "http://example.com/"
			if len(tc.query) > 0 {
				target += "?" + tc.query.Encode()
			}

			// Empty or nil form data encodes to "", which gives an empty body.
			body := strings.NewReader(tc.formData.Encode())

			request := httptest.NewRequest(http.MethodPost, target, body)
			request.Header.Set("Content-Type", "application/x-www-form-urlencoded")

			assert.Equal(t, tc.expected, wasCredentialSubmitted(request, tc.method))
		})
	}
}

View File

@@ -0,0 +1,65 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
// urlGenerator is the client capability the OIDC scheme needs: requesting the
// OIDC URL for a service via GetOIDCURL.
type urlGenerator interface {
	GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
}
// OIDC is the OIDC authentication scheme.
type OIDC struct {
	// id and accountId are sent as Id and AccountId in the GetOIDCURL request.
	id        string
	accountId string
	// forwardedProto is passed to auth.ResolveProto, together with the
	// request's TLS state, to pick the scheme of the redirect URL.
	forwardedProto string
	client         urlGenerator
}

// NewOIDC creates a new OIDC authentication scheme
func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC {
	return OIDC{
		id:             id,
		accountId:      accountId,
		forwardedProto: forwardedProto,
		client:         client,
	}
}

// Type reports auth.MethodOIDC.
func (OIDC) Type() auth.Method {
	return auth.MethodOIDC
}
// Authenticate returns the session token carried in the session_token query
// parameter, if any. Otherwise it asks the client for the OIDC URL, passing a
// redirect URL built from the request's host and path, and returns that URL as
// the prompt data with an empty token.
//
// An error is returned only when the GetOIDCURL call fails.
func (o OIDC) Authenticate(r *http.Request) (string, string, error) {
	// Check for the session_token query param (from OIDC redirects).
	// The management server passes the token in the URL because it cannot set
	// cookies for the proxy's domain (cookies are domain-scoped per RFC 6265).
	if sessionToken := r.URL.Query().Get("session_token"); sessionToken != "" {
		return sessionToken, "", nil
	}

	callback := url.URL{
		Scheme: auth.ResolveProto(o.forwardedProto, r.TLS),
		Host:   r.Host,
		Path:   r.URL.Path,
	}
	request := &proto.GetOIDCURLRequest{
		Id:          o.id,
		AccountId:   o.accountId,
		RedirectUrl: callback.String(),
	}

	res, err := o.client.GetOIDCURL(r.Context(), request)
	if err != nil {
		return "", "", fmt.Errorf("get OIDC URL: %w", err)
	}
	return "", res.GetUrl(), nil
}

View File

@@ -0,0 +1,61 @@
package auth
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
// passwordFormId is the name of the form value that carries the submitted
// password; it is also returned as the prompt data when no token is issued.
const passwordFormId = "password"

// Password is the password authentication scheme.
type Password struct {
	// id and accountId are sent as Id and AccountId in the authenticate request.
	id, accountId string
	client        authenticator
}

// NewPassword creates a new password authentication scheme.
func NewPassword(client authenticator, id, accountId string) Password {
	return Password{
		id:        id,
		accountId: accountId,
		client:    client,
	}
}

// Type reports auth.MethodPassword.
func (Password) Type() auth.Method {
	return auth.MethodPassword
}
// Authenticate checks the password submitted in the request's form values.
//
// With no password submitted, or when the client reports the password as
// wrong, it returns an empty token and the form ID so the UI can prompt the
// user. On success it returns the session token issued by the client. An
// error is returned only when the authenticate call itself fails.
func (p Password) Authenticate(r *http.Request) (string, string, error) {
	submitted := r.FormValue(passwordFormId)
	if submitted == "" {
		// No password submitted; return the form ID so the UI can prompt the user.
		return "", passwordFormId, nil
	}

	request := &proto.AuthenticateRequest{
		Id:        p.id,
		AccountId: p.accountId,
		Request: &proto.AuthenticateRequest_Password{
			Password: &proto.PasswordRequest{
				Password: submitted,
			},
		},
	}
	res, err := p.client.Authenticate(r.Context(), request)
	if err != nil {
		return "", "", fmt.Errorf("authenticate password: %w", err)
	}

	if !res.GetSuccess() {
		return "", passwordFormId, nil
	}
	return res.GetSessionToken(), "", nil
}

View File

@@ -0,0 +1,61 @@
package auth
import (
"fmt"
"net/http"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
// pinFormId is the form field the PIN is read from; Authenticate also
// returns it when a PIN still needs to be supplied.
const pinFormId = "pin"
// Pin is the PIN authentication scheme. It holds the identifiers sent with
// each Authenticate request and the client that issues it.
type Pin struct {
// id and accountId are sent as the Id and AccountId request fields.
id, accountId string
// client performs the Authenticate call.
client authenticator
}
// NewPin returns a Pin scheme that validates submitted PINs through client
// for the given id and accountId.
func NewPin(client authenticator, id, accountId string) Pin {
	var scheme Pin
	scheme.client = client
	scheme.id = id
	scheme.accountId = accountId
	return scheme
}
// Type reports auth.MethodPIN as the method implemented by this scheme.
func (Pin) Type() auth.Method {
return auth.MethodPIN
}
// Authenticate validates the PIN submitted in the request's form values.
//
// On success it returns the session token. When no PIN was submitted, or the
// submitted one was rejected, it returns the form ID the UI must use to
// prompt for and resubmit a PIN. An error is returned only when the
// authentication call itself fails.
func (p Pin) Authenticate(r *http.Request) (string, string, error) {
	submitted := r.FormValue(pinFormId)
	if submitted == "" {
		// Nothing submitted yet; hand back the form ID so the UI can prompt.
		return "", pinFormId, nil
	}

	req := &proto.AuthenticateRequest{
		Id:        p.id,
		AccountId: p.accountId,
		Request: &proto.AuthenticateRequest_Pin{
			Pin: &proto.PinRequest{Pin: submitted},
		},
	}
	res, err := p.client.Authenticate(r.Context(), req)
	if err != nil {
		return "", "", fmt.Errorf("authenticate pin: %w", err)
	}

	if !res.GetSuccess() {
		// Rejected; prompt again.
		return "", pinFormId, nil
	}
	return res.GetSessionToken(), "", nil
}

View File

@@ -0,0 +1,279 @@
// Package certwatch watches TLS certificate files on disk and provides
// a hot-reloading GetCertificate callback for tls.Config.
package certwatch
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
// defaultPollInterval is the poll period every new Watcher starts with.
defaultPollInterval = 30 * time.Second
// debounceDelay is how long fsnotifyLoop waits after a relevant event
// before reloading, so separate cert and key writes are loaded together.
debounceDelay = 500 * time.Millisecond
)
// Watcher monitors TLS certificate files on disk and caches the loaded
// certificate in memory. It detects changes via fsnotify (with a polling
// fallback for filesystems like NFS that lack inotify support) and
// reloads the certificate pair automatically.
type Watcher struct {
// certPath and keyPath are the files handed to tls.LoadX509KeyPair.
certPath string
keyPath string
// mu guards cert and leaf.
mu sync.RWMutex
// cert is the key pair returned by GetCertificate.
cert *tls.Certificate
// leaf is the parsed leaf of cert, kept for change detection and logging.
leaf *x509.Certificate
// pollInterval is the period of the polling ticker.
pollInterval time.Duration
logger *log.Logger
}
// NewWatcher builds a Watcher for the given certificate and key files and
// loads the pair once before returning. A nil logger selects the logrus
// standard logger. An error is returned when that first load fails.
func NewWatcher(certPath, keyPath string, logger *log.Logger) (*Watcher, error) {
	if logger == nil {
		logger = log.StandardLogger()
	}

	watcher := new(Watcher)
	watcher.certPath = certPath
	watcher.keyPath = keyPath
	watcher.pollInterval = defaultPollInterval
	watcher.logger = logger

	err := watcher.reload()
	if err != nil {
		return nil, fmt.Errorf("initial certificate load: %w", err)
	}
	return watcher, nil
}
// GetCertificate returns the certificate currently held in memory. The
// ClientHelloInfo is ignored. It takes the read lock, so it may be called
// concurrently with reloads, and its signature matches
// tls.Config.GetCertificate.
func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	w.mu.RLock()
	current := w.cert
	w.mu.RUnlock()
	return current, nil
}
// Watch blocks, reloading the certificate whenever it changes on disk, until
// ctx is cancelled. Changes are detected through fsnotify; when fsnotify
// cannot be set up (e.g. on NFS) it degrades to polling only. A periodic poll
// keeps running as a safety net even while fsnotify is active.
func (w *Watcher) Watch(ctx context.Context) {
	fsw, err := fsnotify.NewWatcher()
	if err != nil {
		w.logger.Warnf("fsnotify unavailable, using polling only: %v", err)
		w.pollLoop(ctx)
		return
	}
	defer func() {
		if closeErr := fsw.Close(); closeErr != nil {
			w.logger.Debugf("close fsnotify watcher: %v", closeErr)
		}
	}()

	// The parent directories are watched instead of the files themselves:
	// some volume mounts replace content by atomically swapping a symlink
	// (..data -> timestamped dir), which only shows up as a directory event.
	certDir, keyDir := filepath.Dir(w.certPath), filepath.Dir(w.keyPath)

	if addErr := fsw.Add(certDir); addErr != nil {
		w.logger.Warnf("fsnotify watch on %s failed, using polling only: %v", certDir, addErr)
		w.pollLoop(ctx)
		return
	}
	if keyDir != certDir {
		// A failure here is tolerated; only a warning is logged.
		if addErr := fsw.Add(keyDir); addErr != nil {
			w.logger.Warnf("fsnotify watch on %s failed: %v", keyDir, addErr)
		}
	}

	w.logger.Infof("watching certificate files in %s", certDir)
	w.fsnotifyLoop(ctx, fsw)
}
// fsnotifyLoop consumes fsnotify events until ctx is cancelled or one of the
// watcher's channels is closed. Create, write and rename events on a relevant
// file schedule a debounced reload; the poll ticker calls tryReload directly.
func (w *Watcher) fsnotifyLoop(ctx context.Context, watcher *fsnotify.Watcher) {
certBase := filepath.Base(w.certPath)
keyBase := filepath.Base(w.keyPath)
// debounce is the pending reload timer, if any; it is stopped on exit so
// no reload is started after the loop returns.
var debounce *time.Timer
defer func() {
if debounce != nil {
debounce.Stop()
}
}()
// Periodic poll as a safety net for missed fsnotify events.
pollTicker := time.NewTicker(w.pollInterval)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return
case event, ok := <-watcher.Events:
if !ok {
return
}
base := filepath.Base(event.Name)
if !isRelevantFile(base, certBase, keyBase) {
w.logger.Debugf("fsnotify: ignoring event %s on %s", event.Op, event.Name)
continue
}
// Only create, write and rename can mean new content is in place.
if !event.Has(fsnotify.Create) && !event.Has(fsnotify.Write) && !event.Has(fsnotify.Rename) {
w.logger.Debugf("fsnotify: ignoring op %s on %s", event.Op, base)
continue
}
w.logger.Debugf("fsnotify: detected %s on %s, scheduling reload", event.Op, base)
// Debounce: cert-manager may write cert and key as separate
// operations. Wait briefly to load both at once.
if debounce != nil {
debounce.Stop()
}
debounce = time.AfterFunc(debounceDelay, func() {
// The timer may fire after cancellation; skip the reload then.
if ctx.Err() != nil {
return
}
w.tryReload()
})
case err, ok := <-watcher.Errors:
if !ok {
return
}
w.logger.Warnf("fsnotify error: %v", err)
case <-pollTicker.C:
w.tryReload()
}
}
}
// pollLoop is the polling-only mode: it calls tryReload every pollInterval
// and returns once ctx is cancelled.
func (w *Watcher) pollLoop(ctx context.Context) {
	tick := time.NewTicker(w.pollInterval)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-tick.C:
			w.tryReload()
		case <-done:
			return
		}
	}
}
// reload reads the key pair from disk and stores it in the in-memory cache
// unconditionally, then logs the loaded certificate's details. It returns
// the load or parse error, if any, leaving the cache untouched.
func (w *Watcher) reload() error {
	pair, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
	if err != nil {
		return err
	}

	// Make sure the parsed leaf is available so later reloads can compare
	// against it.
	if pair.Leaf == nil && len(pair.Certificate) > 0 {
		parsed, parseErr := x509.ParseCertificate(pair.Certificate[0])
		if parseErr != nil {
			return fmt.Errorf("parse leaf certificate: %w", parseErr)
		}
		pair.Leaf = parsed
	}
	leaf := pair.Leaf

	w.mu.Lock()
	w.cert, w.leaf = &pair, leaf
	w.mu.Unlock()

	w.logCertDetails("loaded certificate", leaf)
	return nil
}
// tryReload re-reads the key pair from disk and swaps it into the in-memory
// cache when the leaf certificate differs from the cached one. Load and parse
// failures are logged as warnings and leave the cached certificate in place.
//
// "Unchanged" means the leaf's DER encoding is byte-identical to the cached
// leaf. Comparing only the serial number and issuer common name is not
// enough: a regenerated certificate that reuses both (for example a
// self-signed certificate created with a fixed serial) would never be picked
// up. Skipping identical certificates avoids redundant log noise and keeps
// the cached *tls.Certificate pointer stable.
func (w *Watcher) tryReload() {
	cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
	if err != nil {
		w.logger.Warnf("reload certificate: %v", err)
		return
	}

	if cert.Leaf == nil && len(cert.Certificate) > 0 {
		leaf, err := x509.ParseCertificate(cert.Certificate[0])
		if err != nil {
			w.logger.Warnf("parse reloaded leaf certificate: %v", err)
			return
		}
		cert.Leaf = leaf
	}

	w.mu.Lock()
	// Certificate.Equal compares the raw DER bytes and reports false when
	// exactly one side is nil, so a nil cert.Leaf never counts as unchanged.
	if w.leaf != nil && w.leaf.Equal(cert.Leaf) {
		w.mu.Unlock()
		return
	}
	prev := w.leaf
	w.cert = &cert
	w.leaf = cert.Leaf
	w.mu.Unlock()

	w.logCertChange(prev, cert.Leaf)
}
// logCertDetails logs msg at info level, followed by the subject, serial,
// SANs and expiry of leaf when a leaf is available.
func (w *Watcher) logCertDetails(msg string, leaf *x509.Certificate) {
	if leaf != nil {
		serial := leaf.SerialNumber.Text(16)
		expiry := leaf.NotAfter.UTC().Format(time.RFC3339)
		w.logger.Infof("%s: subject=%q serial=%s SANs=%v notAfter=%s",
			msg, leaf.Subject.CommonName, serial, leaf.DNSNames, expiry)
		return
	}
	w.logger.Info(msg)
}
// logCertChange logs a reload as a before/after comparison of subject, serial
// and expiry. When either certificate is missing it falls back to logging the
// details of next alone.
func (w *Watcher) logCertChange(prev, next *x509.Certificate) {
	if prev != nil && next != nil {
		expiry := func(c *x509.Certificate) string {
			return c.NotAfter.UTC().Format(time.RFC3339)
		}
		w.logger.Infof("certificate reloaded from disk: subject=%q -> %q serial=%s -> %s notAfter=%s -> %s",
			prev.Subject.CommonName, next.Subject.CommonName,
			prev.SerialNumber.Text(16), next.SerialNumber.Text(16),
			expiry(prev), expiry(next),
		)
		return
	}
	w.logCertDetails("certificate reloaded from disk", next)
}
// isRelevantFile reports whether a changed file name should trigger a reload:
// the certificate file, the key file, or the ..data symlink that atomic
// volume mounts swap to publish new content.
func isRelevantFile(changed, certBase, keyBase string) bool {
	switch changed {
	case certBase, keyBase, "..data":
		return true
	default:
		return false
	}
}

View File

@@ -0,0 +1,292 @@
package certwatch
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// generateSelfSignedCert creates a fresh P-256 key and a self-signed
// certificate with the given serial number and subject CN "test", valid from
// one hour ago until 24 hours from now. Both are returned PEM-encoded.
func generateSelfSignedCert(t *testing.T, serial int64) (certPEM, keyPEM []byte) {
	t.Helper()

	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	require.NoError(t, err)

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(serial),
		Subject:      pkix.Name{CommonName: "test"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(24 * time.Hour),
	}

	// The template is its own parent, which makes the certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
	require.NoError(t, err)
	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: der}
	certPEM = pem.EncodeToMemory(&certBlock)

	privDER, err := x509.MarshalECPrivateKey(priv)
	require.NoError(t, err)
	keyBlock := pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}
	keyPEM = pem.EncodeToMemory(&keyBlock)

	return certPEM, keyPEM
}
// writeCert writes the PEM pair into dir as tls.crt and tls.key (mode 0600),
// certificate first, failing the test on any write error.
func writeCert(t *testing.T, dir string, certPEM, keyPEM []byte) {
	t.Helper()

	certFile := filepath.Join(dir, "tls.crt")
	err := os.WriteFile(certFile, certPEM, 0o600)
	require.NoError(t, err)

	keyFile := filepath.Join(dir, "tls.key")
	err = os.WriteFile(keyFile, keyPEM, 0o600)
	require.NoError(t, err)
}
// TestNewWatcher checks that a new Watcher serves the certificate that was on
// disk when it was created.
func TestNewWatcher(t *testing.T) {
	tmp := t.TempDir()
	crt, key := generateSelfSignedCert(t, 1)
	writeCert(t, tmp, crt, key)

	certFile := filepath.Join(tmp, "tls.crt")
	keyFile := filepath.Join(tmp, "tls.key")
	watcher, err := NewWatcher(certFile, keyFile, nil)
	require.NoError(t, err)

	got, err := watcher.GetCertificate(nil)
	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, int64(1), got.Leaf.SerialNumber.Int64())
}
// TestNewWatcherMissingFiles checks that NewWatcher fails when the
// certificate and key files do not exist.
func TestNewWatcherMissingFiles(t *testing.T) {
	emptyDir := t.TempDir()
	certFile := filepath.Join(emptyDir, "tls.crt")
	keyFile := filepath.Join(emptyDir, "tls.key")

	_, err := NewWatcher(certFile, keyFile, nil)
	assert.Error(t, err)
}
// TestReload checks that tryReload swaps in a certificate with a different
// serial number once it has been written to disk.
func TestReload(t *testing.T) {
	tmp := t.TempDir()
	firstCert, firstKey := generateSelfSignedCert(t, 100)
	writeCert(t, tmp, firstCert, firstKey)

	certFile := filepath.Join(tmp, "tls.crt")
	keyFile := filepath.Join(tmp, "tls.key")
	watcher, err := NewWatcher(certFile, keyFile, nil)
	require.NoError(t, err)

	before, err := watcher.GetCertificate(nil)
	require.NoError(t, err)
	assert.Equal(t, int64(100), before.Leaf.SerialNumber.Int64())

	// Replace the pair on disk with one that has a different serial.
	secondCert, secondKey := generateSelfSignedCert(t, 200)
	writeCert(t, tmp, secondCert, secondKey)

	// Trigger the reload by hand instead of waiting for the watcher.
	watcher.tryReload()

	after, err := watcher.GetCertificate(nil)
	require.NoError(t, err)
	assert.Equal(t, int64(200), after.Leaf.SerialNumber.Int64())
}
// TestTryReloadSkipsUnchanged checks that tryReload keeps the cached
// certificate pointer when the files on disk have not changed.
func TestTryReloadSkipsUnchanged(t *testing.T) {
	tmp := t.TempDir()
	crt, key := generateSelfSignedCert(t, 42)
	writeCert(t, tmp, crt, key)

	certFile := filepath.Join(tmp, "tls.crt")
	keyFile := filepath.Join(tmp, "tls.key")
	watcher, err := NewWatcher(certFile, keyFile, nil)
	require.NoError(t, err)

	before, err := watcher.GetCertificate(nil)
	require.NoError(t, err)

	// Nothing changed on disk, so the reload must be a no-op.
	watcher.tryReload()

	after, err := watcher.GetCertificate(nil)
	require.NoError(t, err)
	assert.Same(t, before, after, "cert pointer should not change when content is the same")
}
// TestWatchDetectsChanges runs Watch in the background, overwrites the
// certificate files and waits for the new serial number to be served.
func TestWatchDetectsChanges(t *testing.T) {
dir := t.TempDir()
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM1, keyPEM1)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
// Use a short poll interval for the test.
w.pollInterval = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go w.Watch(ctx)
// Write new cert.
certPEM2, keyPEM2 := generateSelfSignedCert(t, 999)
writeCert(t, dir, certPEM2, keyPEM2)
// Wait for the watcher to pick it up.
require.Eventually(t, func() bool {
cert, err := w.GetCertificate(nil)
if err != nil {
return false
}
return cert.Leaf.SerialNumber.Int64() == 999
}, 5*time.Second, 50*time.Millisecond, "watcher should detect cert change")
}
// TestIsRelevantFile checks which changed file names count as relevant for a
// tls.crt / tls.key pair.
func TestIsRelevantFile(t *testing.T) {
	const certBase, keyBase = "tls.crt", "tls.key"

	for _, name := range []string{"tls.crt", "tls.key", "..data"} {
		assert.True(t, isRelevantFile(name, certBase, keyBase))
	}
	assert.False(t, isRelevantFile("other.txt", certBase, keyBase))
}
// TestWatchSymlinkRotation simulates Kubernetes secret volume updates where
// the data directory is atomically swapped via a ..data symlink. The watched
// tls.crt and tls.key are themselves symlinks through ..data, so the swap
// changes what they resolve to without touching them.
func TestWatchSymlinkRotation(t *testing.T) {
base := t.TempDir()
// Create initial target directory with certs.
dir1 := filepath.Join(base, "dir1")
require.NoError(t, os.Mkdir(dir1, 0o755))
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.crt"), certPEM1, 0o600))
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.key"), keyPEM1, 0o600))
// Create ..data symlink pointing to dir1.
dataLink := filepath.Join(base, "..data")
require.NoError(t, os.Symlink(dir1, dataLink))
// Create tls.crt and tls.key as symlinks to ..data/{file}.
certLink := filepath.Join(base, "tls.crt")
keyLink := filepath.Join(base, "tls.key")
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.crt"), certLink))
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.key"), keyLink))
w, err := NewWatcher(certLink, keyLink, nil)
require.NoError(t, err)
cert, err := w.GetCertificate(nil)
require.NoError(t, err)
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
// Short poll interval so the polling safety net also fires during the test.
w.pollInterval = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go w.Watch(ctx)
// Simulate k8s atomic rotation: create dir2, swap ..data symlink.
dir2 := filepath.Join(base, "dir2")
require.NoError(t, os.Mkdir(dir2, 0o755))
certPEM2, keyPEM2 := generateSelfSignedCert(t, 777)
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.crt"), certPEM2, 0o600))
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.key"), keyPEM2, 0o600))
// Atomic swap: create temp link, then rename over ..data.
tmpLink := filepath.Join(base, "..data_tmp")
require.NoError(t, os.Symlink(dir2, tmpLink))
require.NoError(t, os.Rename(tmpLink, dataLink))
require.Eventually(t, func() bool {
cert, err := w.GetCertificate(nil)
if err != nil {
return false
}
return cert.Leaf.SerialNumber.Int64() == 777
}, 5*time.Second, 50*time.Millisecond, "watcher should detect symlink rotation")
}
// TestPollLoopDetectsChanges verifies the poll-only fallback path works by
// running pollLoop directly, without any fsnotify watcher.
func TestPollLoopDetectsChanges(t *testing.T) {
dir := t.TempDir()
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM1, keyPEM1)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
// Short poll interval so the change is picked up quickly.
w.pollInterval = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Directly use pollLoop to test the fallback path.
go w.pollLoop(ctx)
certPEM2, keyPEM2 := generateSelfSignedCert(t, 555)
writeCert(t, dir, certPEM2, keyPEM2)
require.Eventually(t, func() bool {
cert, err := w.GetCertificate(nil)
if err != nil {
return false
}
return cert.Leaf.SerialNumber.Int64() == 555
}, 5*time.Second, 50*time.Millisecond, "poll loop should detect cert change")
}
// TestGetCertificateConcurrency calls GetCertificate repeatedly while another
// goroutine calls tryReload, checking that every call returns a certificate
// and no error.
func TestGetCertificateConcurrency(t *testing.T) {
dir := t.TempDir()
certPEM, keyPEM := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM, keyPEM)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
// Hammer GetCertificate concurrently while reloading.
done := make(chan struct{})
go func() {
for i := 0; i < 100; i++ {
w.tryReload()
}
close(done)
}()
for i := 0; i < 1000; i++ {
cert, err := w.GetCertificate(&tls.ClientHelloInfo{})
assert.NoError(t, err)
assert.NotNil(t, cert)
}
// Wait for the reloading goroutine before the test returns.
<-done
}

View File

@@ -0,0 +1,388 @@
// Package debug provides HTTP debug endpoints and CLI client for the proxy server.
package debug
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"time"
)
// StatusFilters contains filter options for status queries. Empty fields are
// left out of the request by ClientStatus.
type StatusFilters struct {
// IPs is sent comma-joined as the filter-by-ips query parameter.
IPs []string
// Names is sent comma-joined as the filter-by-names query parameter.
Names []string
// Status is sent as the filter-by-status query parameter.
Status string
// ConnectionType is sent as the filter-by-connection-type query parameter.
ConnectionType string
}
// Client provides CLI access to debug endpoints.
type Client struct {
// baseURL is the server address including scheme, without a trailing slash.
baseURL string
// jsonOutput makes fetchAndPrint emit JSON instead of the formatted view.
jsonOutput bool
httpClient *http.Client
// out receives everything the client prints.
out io.Writer
}
// NewClient returns a debug client for the server at baseURL. A baseURL
// without an http:// or https:// prefix is treated as plain HTTP, and one
// trailing slash is removed. Output goes to out, as JSON when jsonOutput is
// set. Requests time out after 30 seconds.
func NewClient(baseURL string, jsonOutput bool, out io.Writer) *Client {
	hasScheme := strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://")
	if !hasScheme {
		baseURL = "http://" + baseURL
	}

	client := new(Client)
	client.baseURL = strings.TrimSuffix(baseURL, "/")
	client.jsonOutput = jsonOutput
	client.out = out
	client.httpClient = &http.Client{Timeout: 30 * time.Second}
	return client
}
// Health fetches /debug/health and prints it, formatted by printHealth unless
// JSON output is enabled.
func (c *Client) Health(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/health", c.printHealth)
}
// printHealth writes the human-readable health report: overall status,
// uptime, connectivity flags, certificate counts with their domain lists, and
// finally the per-client table.
//
// Numeric fields are read as float64 because that is how encoding/json
// decodes numbers into map[string]any; missing or mistyped fields count as 0.
// Failed domains are printed in sorted order so the output is the same on
// every run (map iteration order is random).
func (c *Client) printHealth(data map[string]any) {
	_, _ = fmt.Fprintf(c.out, "Status: %v\n", data["status"])
	_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
	_, _ = fmt.Fprintf(c.out, "Management Connected: %s\n", boolIcon(data["management_connected"]))
	_, _ = fmt.Fprintf(c.out, "All Clients Healthy: %s\n", boolIcon(data["all_clients_healthy"]))

	total, _ := data["certs_total"].(float64)
	ready, _ := data["certs_ready"].(float64)
	pending, _ := data["certs_pending"].(float64)
	failed, _ := data["certs_failed"].(float64)
	// The summary line is omitted when no certificates are tracked.
	if total > 0 {
		_, _ = fmt.Fprintf(c.out, "Certificates: %d ready, %d pending, %d failed (%d total)\n",
			int(ready), int(pending), int(failed), int(total))
	}

	if domains, ok := data["certs_ready_domains"].([]any); ok && len(domains) > 0 {
		_, _ = fmt.Fprintf(c.out, " Ready:\n")
		for _, d := range domains {
			_, _ = fmt.Fprintf(c.out, " %v\n", d)
		}
	}
	if domains, ok := data["certs_pending_domains"].([]any); ok && len(domains) > 0 {
		_, _ = fmt.Fprintf(c.out, " Pending:\n")
		for _, d := range domains {
			_, _ = fmt.Fprintf(c.out, " %v\n", d)
		}
	}
	if domains, ok := data["certs_failed_domains"].(map[string]any); ok && len(domains) > 0 {
		_, _ = fmt.Fprintf(c.out, " Failed:\n")
		for _, d := range slices.Sorted(maps.Keys(domains)) {
			_, _ = fmt.Fprintf(c.out, " %s: %v\n", d, domains[d])
		}
	}

	c.printHealthClients(data)
}
// printHealthClients writes the per-client health table from data["clients"],
// one row per account ID. Nothing is printed when the field is missing, has
// the wrong type or is empty. Entries that are not JSON objects are skipped.
//
// Rows are emitted in sorted account-ID order so the table is the same on
// every run (map iteration order is random). Counters are read as float64,
// the type encoding/json produces for numbers; missing counters show as 0.
func (c *Client) printHealthClients(data map[string]any) {
	clients, ok := data["clients"].(map[string]any)
	if !ok || len(clients) == 0 {
		return
	}

	_, _ = fmt.Fprintf(c.out, "\n%-38s %-9s %-7s %-8s %-8s %-16s %s\n",
		"ACCOUNT ID", "HEALTHY", "MGMT", "SIGNAL", "RELAYS", "PEERS (P2P/RLY)", "DEGRADED")
	_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))

	for _, accountID := range slices.Sorted(maps.Keys(clients)) {
		ch, ok := clients[accountID].(map[string]any)
		if !ok {
			continue
		}

		healthy := boolIcon(ch["healthy"])
		mgmt := boolIcon(ch["management_connected"])
		signal := boolIcon(ch["signal_connected"])

		relaysConn, _ := ch["relays_connected"].(float64)
		relaysTotal, _ := ch["relays_total"].(float64)
		relays := fmt.Sprintf("%d/%d", int(relaysConn), int(relaysTotal))

		peersConnected, _ := ch["peers_connected"].(float64)
		peersTotal, _ := ch["peers_total"].(float64)
		peersP2P, _ := ch["peers_p2p"].(float64)
		peersRelayed, _ := ch["peers_relayed"].(float64)
		peersDegraded, _ := ch["peers_degraded"].(float64)
		peers := fmt.Sprintf("%d/%d (%d/%d)", int(peersConnected), int(peersTotal), int(peersP2P), int(peersRelayed))
		degraded := fmt.Sprintf("%d", int(peersDegraded))

		_, _ = fmt.Fprintf(c.out, "%-38s %-9s %-7s %-8s %-8s %-16s %s", accountID, healthy, mgmt, signal, relays, peers, degraded)
		// A per-client error message is appended to the row in parentheses.
		if errMsg, ok := ch["error"].(string); ok && errMsg != "" {
			_, _ = fmt.Fprintf(c.out, " (%s)", errMsg)
		}
		_, _ = fmt.Fprintln(c.out)
	}
}
// boolIcon renders a decoded JSON boolean as "yes" or "no", and anything that
// is not a bool (including nil) as "?".
func boolIcon(v any) string {
	switch flag, isBool := v.(bool); {
	case !isBool:
		return "?"
	case flag:
		return "yes"
	default:
		return "no"
	}
}
// ListClients fetches /debug/clients and prints it, formatted by printClients
// unless JSON output is enabled.
func (c *Client) ListClients(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/clients", c.printClients)
}
// printClients writes the uptime and client count, followed by a table with
// one row per entry of data["clients"]. When that field is missing, has the
// wrong type or is empty, a "no clients" notice is printed instead.
func (c *Client) printClients(data map[string]any) {
	_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
	_, _ = fmt.Fprintf(c.out, "Clients: %v\n\n", data["client_count"])

	rows, isList := data["clients"].([]any)
	if !isList || len(rows) == 0 {
		_, _ = fmt.Fprintln(c.out, "No clients connected.")
		return
	}

	_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
	_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
	for i := range rows {
		c.printClientRow(rows[i])
	}
}
// printClientRow writes one row of the client table for a decoded JSON
// object; items of any other type are silently skipped.
//
// The column widths match the header printed by printClients
// ("%-38s %-12s %-40s %s"), so HAS CLIENT lines up under its heading whenever
// the domain list fits in 40 characters. The account ID and age use %v so a
// missing field prints as <nil> rather than a fmt error marker.
func (c *Client) printClientRow(item any) {
	client, ok := item.(map[string]any)
	if !ok {
		return
	}

	hasClient := "no"
	if hc, _ := client["has_client"].(bool); hc {
		hasClient = "yes"
	}

	_, _ = fmt.Fprintf(c.out, "%-38v %-12v %-40s %s\n",
		client["account_id"],
		client["age"],
		c.extractDomains(client),
		hasClient,
	)
}
// extractDomains renders client["domains"] as a comma-separated list, or "-"
// when the field is missing, is not a list, or is empty.
func (c *Client) extractDomains(client map[string]any) string {
	list, isList := client["domains"].([]any)
	if !isList || len(list) == 0 {
		return "-"
	}

	names := make([]string, 0, len(list))
	for _, entry := range list {
		names = append(names, fmt.Sprint(entry))
	}
	return strings.Join(names, ", ")
}
// ClientStatus fetches and prints the status of the client for accountID.
// Each non-empty filter is sent as a query parameter; list filters are joined
// with commas.
func (c *Client) ClientStatus(ctx context.Context, accountID string, filters StatusFilters) error {
	query := url.Values{}
	setList := func(key string, values []string) {
		if len(values) > 0 {
			query.Set(key, strings.Join(values, ","))
		}
	}
	setValue := func(key, value string) {
		if value != "" {
			query.Set(key, value)
		}
	}

	setList("filter-by-ips", filters.IPs)
	setList("filter-by-names", filters.Names)
	setValue("filter-by-status", filters.Status)
	setValue("filter-by-connection-type", filters.ConnectionType)

	endpoint := "/debug/clients/" + url.PathEscape(accountID)
	if len(query) > 0 {
		endpoint += "?" + query.Encode()
	}
	return c.fetchAndPrint(ctx, endpoint, c.printClientStatus)
}
// printClientStatus writes the account ID, then the status text verbatim when
// data["status"] is a string.
func (c *Client) printClientStatus(data map[string]any) {
	_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])

	text, isString := data["status"].(string)
	if !isString {
		return
	}
	_, _ = fmt.Fprint(c.out, text)
}
// ClientSyncResponse fetches the sync response of the client for accountID
// and prints it as JSON regardless of the client's output mode.
func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error {
	endpoint := fmt.Sprintf("/debug/clients/%s/syncresponse", url.PathEscape(accountID))
	return c.fetchAndPrintJSON(ctx, endpoint)
}
// PingTCP asks the client for accountID to TCP-ping host:port and prints the
// result. The timeout parameter is sent only when non-empty.
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
	query := url.Values{
		"host": {host},
		"port": {fmt.Sprintf("%d", port)},
	}
	if timeout != "" {
		query.Set("timeout", timeout)
	}

	endpoint := "/debug/clients/" + url.PathEscape(accountID) + "/pingtcp?" + query.Encode()
	return c.fetchAndPrint(ctx, endpoint, c.printPingResult)
}
// printPingResult writes the outcome of a TCP ping: target and latency on
// success, target and error message otherwise. A missing or non-bool
// "success" field counts as failure.
func (c *Client) printPingResult(data map[string]any) {
	if succeeded, _ := data["success"].(bool); !succeeded {
		_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
		c.printError(data)
		return
	}
	_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
	_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
}
// SetLogLevel asks the client for accountID to switch to the given log level
// and prints the result.
func (c *Client) SetLogLevel(ctx context.Context, accountID, level string) error {
	query := url.Values{"level": {level}}
	endpoint := fmt.Sprintf("/debug/clients/%s/loglevel?%s", url.PathEscape(accountID), query.Encode())
	return c.fetchAndPrint(ctx, endpoint, c.printLogLevelResult)
}
// printLogLevelResult writes the new log level on success, or a failure
// notice followed by the error message. A missing or non-bool "success" field
// counts as failure.
func (c *Client) printLogLevelResult(data map[string]any) {
	if succeeded, _ := data["success"].(bool); !succeeded {
		_, _ = fmt.Fprintln(c.out, "Failed to set log level")
		c.printError(data)
		return
	}
	_, _ = fmt.Fprintf(c.out, "Log level set to: %v\n", data["level"])
}
// StartClient asks the server to start the client for accountID and prints
// the result.
func (c *Client) StartClient(ctx context.Context, accountID string) error {
	endpoint := fmt.Sprintf("/debug/clients/%s/start", url.PathEscape(accountID))
	return c.fetchAndPrint(ctx, endpoint, c.printStartResult)
}
// printStartResult writes whether the client was started, adding the error
// message on failure. A missing or non-bool "success" field counts as
// failure.
func (c *Client) printStartResult(data map[string]any) {
	if succeeded, _ := data["success"].(bool); !succeeded {
		_, _ = fmt.Fprintln(c.out, "Failed to start client")
		c.printError(data)
		return
	}
	_, _ = fmt.Fprintln(c.out, "Client started")
}
// StopClient asks the server to stop the client for accountID and prints the
// result.
func (c *Client) StopClient(ctx context.Context, accountID string) error {
	endpoint := fmt.Sprintf("/debug/clients/%s/stop", url.PathEscape(accountID))
	return c.fetchAndPrint(ctx, endpoint, c.printStopResult)
}
// printStopResult writes whether the client was stopped, adding the error
// message on failure. A missing or non-bool "success" field counts as
// failure.
func (c *Client) printStopResult(data map[string]any) {
	if succeeded, _ := data["success"].(bool); !succeeded {
		_, _ = fmt.Fprintln(c.out, "Failed to stop client")
		c.printError(data)
		return
	}
	_, _ = fmt.Fprintln(c.out, "Client stopped")
}
// printError writes data["error"] on its own line when it is a string;
// otherwise it prints nothing.
func (c *Client) printError(data map[string]any) {
	message, isString := data["error"].(string)
	if !isString {
		return
	}
	_, _ = fmt.Fprintf(c.out, "Error: %s\n", message)
}
// fetchAndPrint fetches path and writes the response to c.out.
//
// A response that decoded to a JSON object is re-encoded as indented JSON
// when JSON output is enabled, and otherwise handed to printer. A response
// that did not decode to an object is written verbatim in both modes; in
// JSON mode this avoids encoding a nil map, which would print "null" and
// discard the body. fetchAndPrintJSON uses the same fallback.
//
// It returns the fetch error, or the encoding error in JSON mode.
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
	data, raw, err := c.fetch(ctx, path)
	if err != nil {
		return err
	}

	if data == nil {
		_, _ = fmt.Fprintln(c.out, string(raw))
		return nil
	}

	if c.jsonOutput {
		return c.writeJSON(data)
	}

	printer(data)
	return nil
}
// fetchAndPrintJSON fetches path and always prints the response as indented
// JSON, falling back to the raw body when it did not decode to a JSON object.
func (c *Client) fetchAndPrintJSON(ctx context.Context, path string) error {
	data, raw, err := c.fetch(ctx, path)
	switch {
	case err != nil:
		return err
	case data == nil:
		_, _ = fmt.Fprintln(c.out, string(raw))
		return nil
	default:
		return c.writeJSON(data)
	}
}
// writeJSON encodes data to c.out as indented JSON and returns the encoder's
// error, if any.
func (c *Client) writeJSON(data map[string]any) error {
	encoder := json.NewEncoder(c.out)
	encoder.SetIndent("", " ")
	err := encoder.Encode(data)
	return err
}
// fetch issues a GET for path against the base URL, adding format=json to the
// query string unless path already contains it.
//
// It returns the body decoded as a JSON object together with the raw bytes.
// When the body is not a JSON object the decoded map is nil and only the raw
// bytes are returned. Transport failures and HTTP status codes of 400 and
// above are reported as errors.
func (c *Client) fetch(ctx context.Context, path string) (map[string]any, []byte, error) {
	target := c.baseURL + path
	if !strings.Contains(path, "format=json") {
		separator := "?"
		if strings.Contains(path, "?") {
			separator = "&"
		}
		target += separator + "format=json"
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("create request: %w", err)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, nil, fmt.Errorf("request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, nil, fmt.Errorf("read response: %w", err)
	}

	if code := resp.StatusCode; code >= 400 {
		return nil, nil, fmt.Errorf("server error (%d): %s", code, strings.TrimSpace(string(payload)))
	}

	// A body that does not decode is not an error: the caller gets the raw
	// bytes and decides how to present them.
	var decoded map[string]any
	if decodeErr := json.Unmarshal(payload, &decoded); decodeErr != nil {
		return nil, payload, nil
	}
	return decoded, payload, nil
}

View File

@@ -0,0 +1,71 @@
package debug
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
)
// TestPrintHealth_WithCertsAndClients feeds printHealth a payload with
// certificate counts, domain lists and one client, and checks that each of
// those sections shows up in the output.
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
var buf bytes.Buffer
c := NewClient("localhost:8444", false, &buf)
// Numbers are float64, matching what encoding/json produces for map[string]any.
data := map[string]any{
"status": "ok",
"uptime": "1h30m",
"management_connected": true,
"all_clients_healthy": true,
"certs_total": float64(3),
"certs_ready": float64(2),
"certs_pending": float64(1),
"certs_failed": float64(0),
"certs_ready_domains": []any{"a.example.com", "b.example.com"},
"certs_pending_domains": []any{"c.example.com"},
"clients": map[string]any{
"acc-1": map[string]any{
"healthy": true,
"management_connected": true,
"signal_connected": true,
"relays_connected": float64(1),
"relays_total": float64(2),
"peers_connected": float64(3),
"peers_total": float64(5),
"peers_p2p": float64(2),
"peers_relayed": float64(1),
"peers_degraded": float64(0),
},
},
}
c.printHealth(data)
out := buf.String()
assert.Contains(t, out, "Status: ok")
assert.Contains(t, out, "Uptime: 1h30m")
assert.Contains(t, out, "yes") // management_connected
assert.Contains(t, out, "2 ready, 1 pending, 0 failed (3 total)")
assert.Contains(t, out, "a.example.com")
assert.Contains(t, out, "c.example.com")
assert.Contains(t, out, "acc-1")
}
// TestPrintHealth_Minimal checks that printHealth omits the certificate
// summary and the client table when the payload carries neither.
func TestPrintHealth_Minimal(t *testing.T) {
	var output bytes.Buffer
	client := NewClient("localhost:8444", false, &output)

	payload := map[string]any{}
	payload["status"] = "ok"
	payload["uptime"] = "5m"
	payload["management_connected"] = false
	payload["all_clients_healthy"] = false

	client.printHealth(payload)

	got := output.String()
	for _, want := range []string{"Status: ok", "Uptime: 5m"} {
		assert.Contains(t, got, want)
	}
	for _, unwanted := range []string{"Certificates", "ACCOUNT ID"} {
		assert.NotContains(t, got, unwanted)
	}
}

View File

@@ -0,0 +1,712 @@
// Package debug provides HTTP debug endpoints for the proxy server.
package debug
import (
"cmp"
"context"
"embed"
"encoding/json"
"fmt"
"html/template"
"maps"
"net/http"
"slices"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
nbembed "github.com/netbirdio/netbird/client/embed"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/version"
)
// templateFS holds the HTML files under templates/, embedded at build time.
//go:embed templates/*.html
var templateFS embed.FS
// defaultPingTimeout is a 10-second default; presumably applied by the TCP
// ping handler when the request gives no timeout — confirm against handlePingTCP.
const defaultPingTimeout = 10 * time.Second
// formatDuration formats a duration with 2 decimal places using appropriate units.
func formatDuration(d time.Duration) string {
switch {
case d >= time.Hour:
return fmt.Sprintf("%.2fh", d.Hours())
case d >= time.Minute:
return fmt.Sprintf("%.2fm", d.Minutes())
case d >= time.Second:
return fmt.Sprintf("%.2fs", d.Seconds())
case d >= time.Millisecond:
return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000)
case d >= time.Microsecond:
return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000)
default:
return fmt.Sprintf("%dns", d.Nanoseconds())
}
}
// sortedAccountIDs returns the keys of m in ascending order, giving callers a
// stable iteration order over the map.
func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.AccountID {
return slices.Sorted(maps.Keys(m))
}
// clientProvider provides access to NetBird clients.
type clientProvider interface {
// GetClient looks up the client for an account; the bool presumably
// reports whether one exists — confirm against the implementation.
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
// ListClientsForDebug returns debug information keyed by account ID.
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
// healthChecker provides health probe state: two boolean probes and a
// per-account client health report.
type healthChecker interface {
ReadinessProbe() bool
StartupProbe(ctx context.Context) bool
// CheckClientsConnected returns an overall flag plus the health of each
// account's client.
CheckClientsConnected(ctx context.Context) (bool, map[types.AccountID]health.ClientHealth)
}
// certStatus reports certificate state per domain: the total number of
// domains and which of them are pending, ready or failed. It is installed on
// a Handler through SetCertStatus.
type certStatus interface {
TotalDomains() int
PendingDomains() []string
ReadyDomains() []string
// FailedDomains presumably maps each failed domain to its error message
// — confirm against the implementation.
FailedDomains() map[string]string
}
// Handler provides HTTP debug endpoints.
type Handler struct {
provider clientProvider
health healthChecker
// certStatus is optional; it stays nil until SetCertStatus is called.
certStatus certStatus
logger *log.Logger
// startTime is set to the current time when NewHandler runs.
startTime time.Time
// templates holds the parsed embedded templates; it stays nil when
// loading them failed.
templates *template.Template
// templateMu guards templates.
templateMu sync.RWMutex
}
// NewHandler builds a debug Handler around the given client provider and
// health checker. A nil logger selects the logrus standard logger. A failure
// to parse the embedded templates is logged but does not prevent the handler
// from being returned.
func NewHandler(provider clientProvider, healthChecker healthChecker, logger *log.Logger) *Handler {
	if logger == nil {
		logger = log.StandardLogger()
	}

	handler := new(Handler)
	handler.provider = provider
	handler.health = healthChecker
	handler.logger = logger
	handler.startTime = time.Now()

	err := handler.loadTemplates()
	if err != nil {
		logger.Errorf("failed to load embedded templates: %v", err)
	}
	return handler
}
// SetCertStatus sets the certificate status provider for ACME prefetch observability.
// The field is assigned without taking any lock.
func (h *Handler) SetCertStatus(cs certStatus) {
h.certStatus = cs
}
// loadTemplates parses every embedded templates/*.html file and installs the
// result on the handler under the template lock. On a parse error the
// previously installed templates are left untouched.
func (h *Handler) loadTemplates() error {
	parsed, err := template.ParseFS(templateFS, "templates/*.html")
	if err != nil {
		return fmt.Errorf("parse embedded templates: %w", err)
	}

	h.templateMu.Lock()
	defer h.templateMu.Unlock()
	h.templates = parsed
	return nil
}
// getTemplates returns the currently installed templates, read under the
// template lock. The result is nil when no templates were loaded.
func (h *Handler) getTemplates() *template.Template {
	h.templateMu.RLock()
	current := h.templates
	h.templateMu.RUnlock()
	return current
}
// ServeHTTP routes debug requests. JSON output is selected either by the
// format=json query parameter or by a trailing /json path segment, which is
// stripped before routing. Paths that match no route get a 404.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	wantJSON := r.URL.Query().Get("format") == "json" || strings.HasSuffix(r.URL.Path, "/json")
	route := strings.TrimSuffix(r.URL.Path, "/json")

	switch {
	case route == "/debug" || route == "/debug/":
		h.handleIndex(w, r, wantJSON)
	case route == "/debug/clients":
		h.handleListClients(w, r, wantJSON)
	case route == "/debug/health":
		h.handleHealth(w, r, wantJSON)
	case h.handleClientRoutes(w, r, route, wantJSON):
		// The per-client router already wrote the response.
	default:
		http.NotFound(w, r)
	}
}
// handleClientRoutes dispatches paths of the form
// /debug/clients/<accountID>[/<action>] and reports whether it served the
// request. With no action segment the client status page is served; an
// unknown action returns false so the caller can answer with a 404.
func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, path string, wantJSON bool) bool {
	rest, isClientPath := strings.CutPrefix(path, "/debug/clients/")
	if !isClientPath {
		return false
	}

	// Only the first slash separates the account ID from the action, so an
	// action containing further slashes never matches a known route.
	id, action, hasAction := strings.Cut(rest, "/")
	accountID := types.AccountID(id)

	if !hasAction {
		h.handleClientStatus(w, r, accountID, wantJSON)
		return true
	}

	switch action {
	case "syncresponse":
		h.handleClientSyncResponse(w, r, accountID, wantJSON)
	case "tools":
		h.handleClientTools(w, r, accountID)
	case "pingtcp":
		h.handlePingTCP(w, r, accountID)
	case "loglevel":
		h.handleLogLevel(w, r, accountID)
	case "start":
		h.handleClientStart(w, r, accountID)
	case "stop":
		h.handleClientStop(w, r, accountID)
	default:
		return false
	}

	return true
}
// failedDomain pairs a domain with its certificate error text for the index template.
type failedDomain struct {
Domain string
Error string
}
// indexData is the view model passed to the "index" template.
type indexData struct {
Version string
// Uptime is the time since handler creation, rounded to the second.
Uptime string
ClientCount int
// TotalDomains is the sum of DomainCount over all listed clients.
TotalDomains int
CertsTotal int
CertsReady int
CertsPending int
CertsFailed int
CertsPendingDomains []string
CertsReadyDomains []string
// CertsFailedDomains is sorted by domain name.
CertsFailedDomains []failedDomain
Clients []clientData
}
// clientData is one client row in the "index" and "clients" templates.
type clientData struct {
AccountID string
// Domains is the printable domain list, or "-" when it would be empty.
Domains string
// Age is the time since the entry's CreatedAt, rounded to the second.
Age string
// Status is "Active" when the entry has a client, otherwise "No client".
Status string
}
// handleIndex serves the debug landing page: version, uptime, certificate
// counters and the list of clients ordered by account ID. With wantJSON the
// same information is written as a JSON document; otherwise the "index"
// template is rendered.
func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
	clients := h.provider.ListClientsForDebug()
	orderedIDs := sortedAccountIDs(clients)

	var domainTotal int
	for _, c := range clients {
		domainTotal += c.DomainCount
	}

	// Certificate details stay at their zero values when no provider is set.
	var (
		certTotal      int
		pending, ready []string
		failed         map[string]string
	)
	if cs := h.certStatus; cs != nil {
		certTotal = cs.TotalDomains()
		pending = cs.PendingDomains()
		ready = cs.ReadyDomains()
		failed = cs.FailedDomains()
	}
	readyCount, pendingCount, failedCount := len(ready), len(pending), len(failed)

	if wantJSON {
		entries := make([]map[string]interface{}, 0, len(clients))
		for _, id := range orderedIDs {
			c := clients[id]
			entries = append(entries, map[string]interface{}{
				"account_id":   c.AccountID,
				"domain_count": c.DomainCount,
				"domains":      c.Domains,
				"has_client":   c.HasClient,
				"created_at":   c.CreatedAt,
				"age":          time.Since(c.CreatedAt).Round(time.Second).String(),
			})
		}

		payload := map[string]interface{}{
			"version":       version.NetbirdVersion(),
			"uptime":        time.Since(h.startTime).Round(time.Second).String(),
			"client_count":  len(clients),
			"total_domains": domainTotal,
			"certs_total":   certTotal,
			"certs_ready":   readyCount,
			"certs_pending": pendingCount,
			"certs_failed":  failedCount,
			"clients":       entries,
		}
		// Domain lists are only included when they have content.
		if pendingCount > 0 {
			payload["certs_pending_domains"] = pending
		}
		if readyCount > 0 {
			payload["certs_ready_domains"] = ready
		}
		if failedCount > 0 {
			payload["certs_failed_domains"] = failed
		}

		h.writeJSON(w, payload)
		return
	}

	// Map iteration order is random, so sort failures by domain for display.
	failedRows := make([]failedDomain, 0, failedCount)
	for domain, reason := range failed {
		failedRows = append(failedRows, failedDomain{Domain: domain, Error: reason})
	}
	slices.SortFunc(failedRows, func(x, y failedDomain) int {
		return cmp.Compare(x.Domain, y.Domain)
	})

	rows := make([]clientData, 0, len(clients))
	for _, id := range orderedIDs {
		c := clients[id]

		row := clientData{
			AccountID: string(c.AccountID),
			Domains:   c.Domains.SafeString(),
			Age:       time.Since(c.CreatedAt).Round(time.Second).String(),
			Status:    "No client",
		}
		if row.Domains == "" {
			row.Domains = "-"
		}
		if c.HasClient {
			row.Status = "Active"
		}
		rows = append(rows, row)
	}

	h.renderTemplate(w, "index", indexData{
		Version:             version.NetbirdVersion(),
		Uptime:              time.Since(h.startTime).Round(time.Second).String(),
		ClientCount:         len(clients),
		TotalDomains:        domainTotal,
		CertsTotal:          certTotal,
		CertsReady:          readyCount,
		CertsPending:        pendingCount,
		CertsFailed:         failedCount,
		CertsPendingDomains: pending,
		CertsReadyDomains:   ready,
		CertsFailedDomains:  failedRows,
		Clients:             rows,
	})
}
// clientsData is the view model passed to the "clients" template.
type clientsData struct {
// Uptime is the time since handler creation, rounded to the second.
Uptime string
Clients []clientData
}
// handleListClients serves the list of all clients ordered by account ID,
// either as JSON (wantJSON) or through the "clients" template.
func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
	clients := h.provider.ListClientsForDebug()
	orderedIDs := sortedAccountIDs(clients)

	if wantJSON {
		entries := make([]map[string]interface{}, 0, len(clients))
		for _, id := range orderedIDs {
			c := clients[id]
			entries = append(entries, map[string]interface{}{
				"account_id":   c.AccountID,
				"domain_count": c.DomainCount,
				"domains":      c.Domains,
				"has_client":   c.HasClient,
				"created_at":   c.CreatedAt,
				"age":          time.Since(c.CreatedAt).Round(time.Second).String(),
			})
		}

		h.writeJSON(w, map[string]interface{}{
			"uptime":       time.Since(h.startTime).Round(time.Second).String(),
			"client_count": len(clients),
			"clients":      entries,
		})
		return
	}

	view := clientsData{
		Uptime:  time.Since(h.startTime).Round(time.Second).String(),
		Clients: make([]clientData, 0, len(clients)),
	}
	for _, id := range orderedIDs {
		c := clients[id]

		row := clientData{
			AccountID: string(c.AccountID),
			Domains:   c.Domains.SafeString(),
			Age:       time.Since(c.CreatedAt).Round(time.Second).String(),
			Status:    "No client",
		}
		if row.Domains == "" {
			row.Domains = "-"
		}
		if c.HasClient {
			row.Status = "Active"
		}
		view.Clients = append(view.Clients, row)
	}

	h.renderTemplate(w, "clients", view)
}
// clientDetailData is the view model passed to the "clientDetail" template.
type clientDetailData struct {
AccountID string
// ActiveTab selects the highlighted nav entry; handlers set "status" or "syncresponse".
ActiveTab string
// Content is the preformatted text shown on the page.
Content string
}
// handleClientStatus serves the full status summary of one client.
//
// Optional query parameters narrow the output: filter-by-status,
// filter-by-connection-type, filter-by-names (comma separated) and
// filter-by-ips (comma separated). An unknown account yields 404 and a
// status retrieval failure yields 500.
func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, accountID types.AccountID, wantJSON bool) {
	client, found := h.provider.GetClient(accountID)
	if !found {
		http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
		return
	}

	fullStatus, statusErr := client.Status()
	if statusErr != nil {
		http.Error(w, "Error getting status: "+statusErr.Error(), http.StatusInternalServerError)
		return
	}

	params := r.URL.Query()

	// Name filter: the slice keeps the raw entries, the set holds them
	// trimmed and lower-cased.
	var (
		nameList []string
		nameSet  map[string]struct{}
	)
	if raw := params.Get("filter-by-names"); raw != "" {
		nameList = strings.Split(raw, ",")
		nameSet = make(map[string]struct{})
		for _, entry := range nameList {
			key := strings.ToLower(strings.TrimSpace(entry))
			nameSet[key] = struct{}{}
		}
	}

	var ipSet map[string]struct{}
	if raw := params.Get("filter-by-ips"); raw != "" {
		ipSet = make(map[string]struct{})
		for _, entry := range strings.Split(raw, ",") {
			ipSet[strings.TrimSpace(entry)] = struct{}{}
		}
	}

	overview := nbstatus.ConvertToStatusOutputOverview(
		nbstatus.ToProtoFullStatus(fullStatus),
		false,
		version.NetbirdVersion(),
		params.Get("filter-by-status"),
		nameList,
		nameSet,
		ipSet,
		params.Get("filter-by-connection-type"),
		"",
	)
	summary := overview.FullDetailSummary()

	if wantJSON {
		h.writeJSON(w, map[string]interface{}{
			"account_id": accountID,
			"status":     summary,
		})
		return
	}

	h.renderTemplate(w, "clientDetail", clientDetailData{
		AccountID: string(accountID),
		ActiveTab: "status",
		Content:   summary,
	})
}
// handleClientSyncResponse serves the latest sync response of one client as
// indented JSON, either raw (wantJSON) or embedded in the "clientDetail"
// page. Responds 404 when the client or its sync response is missing and 500
// when fetching or marshaling fails.
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
	client, found := h.provider.GetClient(accountID)
	if !found {
		http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
		return
	}

	latest, fetchErr := client.GetLatestSyncResponse()
	switch {
	case fetchErr != nil:
		http.Error(w, "Error getting sync response: "+fetchErr.Error(), http.StatusInternalServerError)
		return
	case latest == nil:
		http.Error(w, "No sync response available for client: "+string(accountID), http.StatusNotFound)
		return
	}

	marshaler := protojson.MarshalOptions{
		EmitUnpopulated: true,
		UseProtoNames:   true,
		Indent:          " ",
		AllowPartial:    true,
	}
	encoded, marshalErr := marshaler.Marshal(latest)
	if marshalErr != nil {
		http.Error(w, "Error marshaling sync response: "+marshalErr.Error(), http.StatusInternalServerError)
		return
	}

	if !wantJSON {
		h.renderTemplate(w, "clientDetail", clientDetailData{
			AccountID: string(accountID),
			ActiveTab: "syncresponse",
			Content:   string(encoded),
		})
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// The response is already committed; a failed write cannot be reported to the peer.
	_, _ = w.Write(encoded)
}
// toolsData is the view model passed to the "tools" template.
type toolsData struct {
AccountID string
}
// handleClientTools renders the interactive tools page for one client, or
// responds 404 when the account has no client.
func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, accountID types.AccountID) {
	if _, found := h.provider.GetClient(accountID); !found {
		http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
		return
	}

	h.renderTemplate(w, "tools", toolsData{AccountID: string(accountID)})
}
// handlePingTCP opens a TCP connection to host:port through the account's
// client and reports the outcome as JSON.
//
// Query parameters:
//   - host: target hostname or IP address (required). A bare IPv6 literal
//     such as fe80::1 is accepted.
//   - port: target port in the range 1-65535 (required).
//   - timeout: optional Go duration string. Values that do not parse or are
//     not positive fall back to defaultPingTimeout.
//
// Every outcome, including validation failures, is written as a JSON body.
// The reported latency covers connection establishment only.
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
	client, ok := h.provider.GetClient(accountID)
	if !ok {
		h.writeJSON(w, map[string]interface{}{"error": "client not found"})
		return
	}

	query := r.URL.Query()
	host := query.Get("host")
	portStr := query.Get("port")
	if host == "" || portStr == "" {
		h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
		return
	}

	port, err := strconv.Atoi(portStr)
	if err != nil || port < 1 || port > 65535 {
		h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
		return
	}

	timeout := defaultPingTimeout
	if raw := query.Get("timeout"); raw != "" {
		// A zero or negative duration would create an already-expired
		// context and fail the dial immediately, so it is ignored.
		if d, err := time.ParseDuration(raw); err == nil && d > 0 {
			timeout = d
		}
	}

	ctx, cancel := context.WithTimeout(r.Context(), timeout)
	defer cancel()

	// A host:port address needs IPv6 literals in brackets; without them the
	// address has too many colons and cannot be split back into host and port.
	dialHost := host
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		dialHost = "[" + host + "]"
	}
	address := fmt.Sprintf("%s:%d", dialHost, port)

	start := time.Now()
	conn, err := client.Dial(ctx, "tcp", address)
	// Take the measurement before closing so teardown time is not counted.
	latency := time.Since(start)
	if err != nil {
		h.writeJSON(w, map[string]interface{}{
			"success": false,
			"host":    host,
			"port":    port,
			"error":   err.Error(),
		})
		return
	}
	if err := conn.Close(); err != nil {
		h.logger.Debugf("close tcp ping connection: %v", err)
	}

	h.writeJSON(w, map[string]interface{}{
		"success":    true,
		"host":       host,
		"port":       port,
		"latency_ms": latency.Milliseconds(),
		"latency":    formatDuration(latency),
	})
}
// handleLogLevel applies the log level given in the "level" query parameter
// to the account's client and reports the outcome as JSON.
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
	client, found := h.provider.GetClient(accountID)
	if !found {
		h.writeJSON(w, map[string]interface{}{"error": "client not found"})
		return
	}

	requested := r.URL.Query().Get("level")
	if requested == "" {
		h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
		return
	}

	var reply map[string]interface{}
	if setErr := client.SetLogLevel(requested); setErr != nil {
		reply = map[string]interface{}{
			"success": false,
			"error":   setErr.Error(),
		}
	} else {
		reply = map[string]interface{}{
			"success": true,
			"level":   requested,
		}
	}
	h.writeJSON(w, reply)
}
// clientActionTimeout bounds how long a client start or stop request may run.
const clientActionTimeout = 30 * time.Second
// handleClientStart starts the account's client, bounded by
// clientActionTimeout and the request context, and reports the outcome as JSON.
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
	client, found := h.provider.GetClient(accountID)
	if !found {
		h.writeJSON(w, map[string]interface{}{"error": "client not found"})
		return
	}

	ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
	defer cancel()

	reply := map[string]interface{}{
		"success": true,
		"message": "client started",
	}
	if startErr := client.Start(ctx); startErr != nil {
		reply = map[string]interface{}{
			"success": false,
			"error":   startErr.Error(),
		}
	}
	h.writeJSON(w, reply)
}
// handleClientStop stops the account's client, bounded by
// clientActionTimeout and the request context, and reports the outcome as JSON.
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
	client, found := h.provider.GetClient(accountID)
	if !found {
		h.writeJSON(w, map[string]interface{}{"error": "client not found"})
		return
	}

	ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
	defer cancel()

	reply := map[string]interface{}{
		"success": true,
		"message": "client stopped",
	}
	if stopErr := client.Stop(ctx); stopErr != nil {
		reply = map[string]interface{}{
			"success": false,
			"error":   stopErr.Error(),
		}
	}
	h.writeJSON(w, reply)
}
// handleHealth serves a JSON health summary: overall status, uptime,
// management connectivity, per-client health and certificate counters.
// The endpoint is JSON-only; other requests are redirected to /debug.
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
	if !wantJSON {
		http.Redirect(w, r, "/debug", http.StatusSeeOther)
		return
	}

	uptime := time.Since(h.startTime).Round(10 * time.Millisecond).String()
	mgmtReady := h.health.ReadinessProbe()
	clientsOK, perClient := h.health.CheckClientsConnected(r.Context())

	// No clients is not a health issue; only degrade when actual clients are unhealthy
	overall := "ok"
	if unhealthyClients := !clientsOK && len(perClient) > 0; !mgmtReady || unhealthyClients {
		overall = "degraded"
	}

	// Certificate details stay at their zero values when no provider is set.
	var (
		certTotal      int
		pending, ready []string
		failed         map[string]string
	)
	if cs := h.certStatus; cs != nil {
		certTotal = cs.TotalDomains()
		pending = cs.PendingDomains()
		ready = cs.ReadyDomains()
		failed = cs.FailedDomains()
	}

	payload := map[string]any{
		"status":               overall,
		"uptime":               uptime,
		"management_connected": mgmtReady,
		"all_clients_healthy":  clientsOK,
		"certs_total":          certTotal,
		"certs_ready":          len(ready),
		"certs_pending":        len(pending),
		"certs_failed":         len(failed),
		"clients":              perClient,
	}
	// Domain lists are only included when they have content.
	if len(pending) > 0 {
		payload["certs_pending_domains"] = pending
	}
	if len(ready) > 0 {
		payload["certs_ready_domains"] = ready
	}
	if len(failed) > 0 {
		payload["certs_failed_domains"] = failed
	}

	h.writeJSON(w, payload)
}
// renderTemplate executes the named template with data and writes the result
// as an HTML response.
//
// The page is rendered into memory first and only sent once execution has
// succeeded. Executing straight into w would, on a mid-render failure, leave
// a partial page already sent with a 200 status, and the following
// http.Error call would append to it instead of producing a clean 500.
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
	tmpl := h.getTemplates()
	if tmpl == nil {
		http.Error(w, "Templates not loaded", http.StatusInternalServerError)
		return
	}

	var page strings.Builder
	if err := tmpl.ExecuteTemplate(&page, name, data); err != nil {
		h.logger.Errorf("execute template %s: %v", name, err)
		http.Error(w, "Template error", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	if _, err := w.Write([]byte(page.String())); err != nil {
		// The status line is already sent; all that is left is to log.
		h.logger.Debugf("write template %s response: %v", name, err)
	}
}
// writeJSON encodes v as indented JSON into the response with a JSON content
// type. Encoding failures are logged, since the response may already be
// partially written.
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
	w.Header().Set("Content-Type", "application/json")

	encoder := json.NewEncoder(w)
	encoder.SetIndent("", " ")

	encodeErr := encoder.Encode(v)
	if encodeErr == nil {
		return
	}
	h.logger.Errorf("encode JSON response: %v", encodeErr)
}

View File

@@ -0,0 +1,101 @@
{{/* style: shared stylesheet for the debug pages. Each page includes it inside its own style element. */}}
{{define "style"}}
body {
font-family: monospace;
margin: 20px;
background: #1a1a1a;
color: #eee;
}
a {
color: #6cf;
}
h1, h2, h3 {
color: #fff;
}
.info {
color: #aaa;
}
table {
border-collapse: collapse;
margin: 10px 0;
}
th, td {
border: 1px solid #444;
padding: 8px;
text-align: left;
}
th {
background: #333;
}
.nav {
margin-bottom: 20px;
}
.nav a {
margin-right: 15px;
padding: 8px 16px;
background: #333;
text-decoration: none;
border-radius: 4px;
}
.nav a.active {
background: #6cf;
color: #000;
}
pre {
background: #222;
padding: 15px;
border-radius: 4px;
overflow-x: auto;
white-space: pre-wrap;
}
input, select, textarea {
background: #333;
color: #eee;
border: 1px solid #555;
padding: 8px;
border-radius: 4px;
font-family: monospace;
}
input:focus, select:focus, textarea:focus {
outline: none;
border-color: #6cf;
}
button {
background: #6cf;
color: #000;
border: none;
padding: 8px 16px;
border-radius: 4px;
cursor: pointer;
font-family: monospace;
}
button:hover {
background: #5be;
}
button:disabled {
background: #555;
color: #888;
cursor: not-allowed;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
margin-bottom: 5px;
color: #aaa;
}
.form-row {
display: flex;
gap: 10px;
align-items: flex-end;
}
.result {
margin-top: 20px;
}
.success {
color: #5f5;
}
.error {
color: #f55;
}
{{end}}

View File

@@ -0,0 +1,19 @@
{{/* clientDetail: per-client page showing preformatted text. Reads .AccountID, .ActiveTab (the nav entry to highlight) and .Content. */}}
{{define "clientDetail"}}
<!DOCTYPE html>
<html lang="en">
<head>
<title>Client {{.AccountID}}</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>Client: {{.AccountID}}</h1>
<div class="nav">
<a href="/debug">&larr; Back</a>
<a href="/debug/clients/{{.AccountID}}/tools"{{if eq .ActiveTab "tools"}} class="active"{{end}}>Tools</a>
<a href="/debug/clients/{{.AccountID}}"{{if eq .ActiveTab "status"}} class="active"{{end}}>Status</a>
<a href="/debug/clients/{{.AccountID}}/syncresponse"{{if eq .ActiveTab "syncresponse"}} class="active"{{end}}>Sync Response</a>
</div>
<pre>{{.Content}}</pre>
</body>
</html>
{{end}}

View File

@@ -0,0 +1,33 @@
{{/* clients: table of all clients. Reads .Uptime and .Clients; each row uses .AccountID, .Domains, .Age and .Status. */}}
{{define "clients"}}
<!DOCTYPE html>
<html lang="en">
<head>
<title>Clients</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>All Clients</h1>
<p class="info">Uptime: {{.Uptime}} | <a href="/debug">&larr; Back</a></p>
{{if .Clients}}
<table>
<tr>
<th>Account ID</th>
<th>Domains</th>
<th>Age</th>
<th>Status</th>
</tr>
{{range .Clients}}
<tr>
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
<td>{{.Domains}}</td>
<td>{{.Age}}</td>
<td>{{.Status}}</td>
</tr>
{{end}}
</table>
{{else}}
<p>No clients connected</p>
{{end}}
</body>
</html>
{{end}}

View File

@@ -0,0 +1,58 @@
{{/* index: debug landing page with version, uptime, certificate counters and domain lists, and the client table. Each domain list is shown only when it is non-empty. */}}
{{define "index"}}
<!DOCTYPE html>
<html lang="en">
<head>
<title>NetBird Proxy Debug</title>
<style>{{template "style"}}</style>
</head>
<body>
<h1>NetBird Proxy Debug</h1>
<p class="info">Version: {{.Version}} | Uptime: {{.Uptime}}</p>
<h2>Certificates: {{.CertsReady}} ready, {{.CertsPending}} pending, {{.CertsFailed}} failed ({{.CertsTotal}} total)</h2>
{{if .CertsReadyDomains}}
<details>
<summary>Ready domains ({{.CertsReady}})</summary>
<ul>{{range .CertsReadyDomains}}<li>{{.}}</li>{{end}}</ul>
</details>
{{end}}
{{if .CertsPendingDomains}}
<details open>
<summary>Pending domains ({{.CertsPending}})</summary>
<ul>{{range .CertsPendingDomains}}<li>{{.}}</li>{{end}}</ul>
</details>
{{end}}
{{if .CertsFailedDomains}}
<details open>
<summary>Failed domains ({{.CertsFailed}})</summary>
<ul>{{range .CertsFailedDomains}}<li>{{.Domain}}: {{.Error}}</li>{{end}}</ul>
</details>
{{end}}
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
{{if .Clients}}
<table>
<tr>
<th>Account ID</th>
<th>Domains</th>
<th>Age</th>
<th>Status</th>
</tr>
{{range .Clients}}
<tr>
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
<td>{{.Domains}}</td>
<td>{{.Age}}</td>
<td>{{.Status}}</td>
</tr>
{{end}}
</table>
{{else}}
<p>No clients connected</p>
{{end}}
<h2>Endpoints</h2>
<ul>
<li><a href="/debug/clients">/debug/clients</a> - all clients detail</li>
</ul>
<p class="info">Add ?format=json or /json suffix for JSON output</p>
</body>
</html>
{{end}}

View File

@@ -0,0 +1,142 @@
{{/* tools: interactive per-client page (start/stop, log level, TCP ping). Reads .AccountID. Results from the JSON endpoints are displayed as plain text, never parsed as HTML. */}}
{{define "tools"}}
<!DOCTYPE html>
<html lang="en">
<head>
    <title>Client {{.AccountID}} - Tools</title>
    <style>{{template "style"}}</style>
</head>
<body>
    <h1>Client: {{.AccountID}}</h1>
    <div class="nav">
        <a href="/debug">&larr; Back</a>
        <a href="/debug/clients/{{.AccountID}}/tools" class="active">Tools</a>
        <a href="/debug/clients/{{.AccountID}}">Status</a>
        <a href="/debug/clients/{{.AccountID}}/syncresponse">Sync Response</a>
    </div>
    <h2>Client Control</h2>
    <div class="form-row">
        <div class="form-group">
            <span>&nbsp;</span>
            <button onclick="startClient()">Start</button>
        </div>
        <div class="form-group">
            <span>&nbsp;</span>
            <button onclick="stopClient()">Stop</button>
        </div>
    </div>
    <div id="client-result" class="result"></div>
    <h2>Log Level</h2>
    <div class="form-row">
        <div class="form-group">
            <label for="log-level">Level</label>
            <select id="log-level" style="width: 120px;">
                <option value="trace">trace</option>
                <option value="debug">debug</option>
                <option value="info">info</option>
                <option value="warn" selected>warn</option>
                <option value="error">error</option>
            </select>
        </div>
        <div class="form-group">
            <span>&nbsp;</span>
            <button onclick="setLogLevel()">Set Level</button>
        </div>
    </div>
    <div id="log-result" class="result"></div>
    <h2>TCP Ping</h2>
    <div class="form-row">
        <div class="form-group">
            <label for="tcp-host">Host</label>
            <input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
        </div>
        <div class="form-group">
            <label for="tcp-port">Port</label>
            <input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
        </div>
        <div class="form-group">
            <span>&nbsp;</span>
            <button onclick="doTcpPing()">Connect</button>
        </div>
    </div>
    <div id="tcp-result" class="result"></div>
    <script>
        const accountID = "{{.AccountID}}";
        // Base URL for this client's endpoints; the ID is encoded so it
        // cannot alter the path structure.
        const clientURL = '/debug/clients/' + encodeURIComponent(accountID);

        // showResult replaces the contents of target with one span.
        // The text goes through textContent, so error messages and echoed
        // input coming back from the server are never interpreted as markup.
        function showResult(target, cssClass, text) {
            const span = document.createElement('span');
            span.className = cssClass;
            span.textContent = text;
            target.textContent = '';
            target.appendChild(span);
        }

        async function startClient() {
            const resultDiv = document.getElementById('client-result');
            showResult(resultDiv, 'info', 'Starting client...');
            try {
                const resp = await fetch(clientURL + '/start');
                const data = await resp.json();
                if (data.success) {
                    showResult(resultDiv, 'success', '✓ ' + data.message);
                } else {
                    showResult(resultDiv, 'error', '✗ ' + data.error);
                }
            } catch (e) {
                showResult(resultDiv, 'error', 'Error: ' + e.message);
            }
        }

        async function stopClient() {
            const resultDiv = document.getElementById('client-result');
            showResult(resultDiv, 'info', 'Stopping client...');
            try {
                const resp = await fetch(clientURL + '/stop');
                const data = await resp.json();
                if (data.success) {
                    showResult(resultDiv, 'success', '✓ ' + data.message);
                } else {
                    showResult(resultDiv, 'error', '✗ ' + data.error);
                }
            } catch (e) {
                showResult(resultDiv, 'error', 'Error: ' + e.message);
            }
        }

        async function setLogLevel() {
            const level = document.getElementById('log-level').value;
            const resultDiv = document.getElementById('log-result');
            showResult(resultDiv, 'info', 'Setting log level...');
            try {
                const resp = await fetch(clientURL + '/loglevel?level=' + encodeURIComponent(level));
                const data = await resp.json();
                if (data.success) {
                    showResult(resultDiv, 'success', '✓ Log level set to: ' + data.level);
                } else {
                    showResult(resultDiv, 'error', '✗ ' + data.error);
                }
            } catch (e) {
                showResult(resultDiv, 'error', 'Error: ' + e.message);
            }
        }

        async function doTcpPing() {
            const host = document.getElementById('tcp-host').value;
            const port = document.getElementById('tcp-port').value;
            if (!host || !port) {
                alert('Host and port required');
                return;
            }
            const resultDiv = document.getElementById('tcp-result');
            showResult(resultDiv, 'info', 'Connecting...');
            try {
                const resp = await fetch(clientURL + '/pingtcp?host=' + encodeURIComponent(host) + '&port=' + encodeURIComponent(port));
                const data = await resp.json();
                if (data.success) {
                    showResult(resultDiv, 'success', '✓ ' + data.host + ':' + data.port + ' connected in ' + data.latency);
                } else {
                    showResult(resultDiv, 'error', '✗ ' + data.host + ':' + data.port + ': ' + data.error);
                }
            } catch (e) {
                showResult(resultDiv, 'error', 'Error: ' + e.message);
            }
        }
    </script>
</body>
</html>
{{end}}

View File

@@ -0,0 +1,20 @@
//go:build !unix
package flock
import (
"context"
"os"
)
// Lock is a no-op on non-Unix platforms. Returns (nil, nil) to indicate
// that no lock was acquired; callers must treat a nil file as "proceed
// without lock" rather than "lock held by someone else."
// Both arguments are ignored and no file is created.
func Lock(_ context.Context, _ string) (*os.File, error) {
return nil, nil //nolint:nilnil // intentional: nil file signals locking unsupported on this platform
}
// Unlock is a no-op on non-Unix platforms.
// It ignores its argument and always returns nil.
func Unlock(_ *os.File) error {
return nil
}

View File

@@ -0,0 +1,79 @@
//go:build unix
package flock
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLockUnlock checks the uncontended path: Lock returns a file, the lock
// file exists on disk, and Unlock succeeds.
func TestLockUnlock(t *testing.T) {
	path := filepath.Join(t.TempDir(), "test.lock")

	held, lockErr := Lock(context.Background(), path)
	require.NoError(t, lockErr)
	require.NotNil(t, held)

	_, statErr := os.Stat(path)
	assert.NoError(t, statErr, "lock file should exist")

	assert.NoError(t, Unlock(held))
}
func TestUnlockNil(t *testing.T) {
err := Unlock(nil)
assert.NoError(t, err, "unlocking nil should be a no-op")
}
// TestLockRespectsContext checks that a second Lock on a held path gives up
// with context.DeadlineExceeded once its context times out.
func TestLockRespectsContext(t *testing.T) {
	path := filepath.Join(t.TempDir(), "test.lock")

	held, err := Lock(context.Background(), path)
	require.NoError(t, err)
	defer func() { require.NoError(t, Unlock(held)) }()

	waitCtx, stop := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer stop()

	_, contendedErr := Lock(waitCtx, path)
	assert.ErrorIs(t, contendedErr, context.DeadlineExceeded)
}
// TestLockBlocks checks that a second Lock call on the same path waits until
// the first holder releases the lock.
func TestLockBlocks(t *testing.T) {
lockPath := filepath.Join(t.TempDir(), "test.lock")
f1, err := Lock(context.Background(), lockPath)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
// elapsed is written by the goroutine and read only after wg.Wait.
var elapsed time.Duration
go func() {
defer wg.Done()
// Contends with f1; expected to return only after f1 is unlocked below.
f2, err := Lock(context.Background(), lockPath)
elapsed = time.Since(start)
assert.NoError(t, err)
if f2 != nil {
assert.NoError(t, Unlock(f2))
}
}()
// Hold the lock for 200ms, then release.
time.Sleep(200 * time.Millisecond)
require.NoError(t, Unlock(f1))
wg.Wait()
// The 150ms bound leaves slack below the 200ms hold.
assert.GreaterOrEqual(t, elapsed, 150*time.Millisecond,
"Lock should have blocked for at least ~200ms")
}

View File

@@ -0,0 +1,77 @@
//go:build unix
// Package flock provides best-effort advisory file locking using flock(2).
//
// This is used for cross-replica coordination (e.g. preventing duplicate
// ACME requests). Note that flock(2) does NOT work reliably on NFS volumes:
// on NFSv3 it depends on the NLM daemon, on NFSv4 Linux emulates it via
// fcntl locks with different semantics. Callers must treat lock failures
// as non-fatal and proceed without the lock.
package flock
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
// retryInterval is the pause between non-blocking lock attempts in Lock.
const retryInterval = 100 * time.Millisecond
// Lock acquires an exclusive advisory lock on the given file path.
// It creates the lock file if it does not exist. The lock attempt
// respects context cancellation by using non-blocking flock with polling.
// The caller must call Unlock with the returned *os.File when done.
func Lock(ctx context.Context, path string) (*os.File, error) {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600)
if err != nil {
return nil, fmt.Errorf("open lock file %s: %w", path, err)
}
timer := time.NewTimer(retryInterval)
defer timer.Stop()
for {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
return f, nil
} else if !errors.Is(err, syscall.EWOULDBLOCK) {
if cerr := f.Close(); cerr != nil {
log.Debugf("close lock file %s: %v", path, cerr)
}
return nil, fmt.Errorf("acquire lock on %s: %w", path, err)
}
select {
case <-ctx.Done():
if cerr := f.Close(); cerr != nil {
log.Debugf("close lock file %s: %v", path, cerr)
}
return nil, ctx.Err()
case <-timer.C:
timer.Reset(retryInterval)
}
}
}
// Unlock releases the advisory lock held on f and closes the file.
// A nil file is accepted and does nothing. The file is closed even when
// releasing the lock fails; a close error is only logged.
func Unlock(f *os.File) error {
	if f == nil {
		return nil
	}

	unlockErr := syscall.Flock(int(f.Fd()), syscall.LOCK_UN)

	if closeErr := f.Close(); closeErr != nil {
		log.Debugf("close lock file: %v", closeErr)
	}

	if unlockErr != nil {
		return fmt.Errorf("release lock: %w", unlockErr)
	}
	return nil
}

View File

@@ -0,0 +1,48 @@
// Package grpc provides gRPC utilities for the proxy client.
package grpc
import (
"context"
"os"
"strconv"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// EnvProxyAllowInsecure controls whether the proxy token can be sent over non-TLS connections.
// WithProxyToken parses the value with strconv.ParseBool; an empty or invalid value keeps the secure default.
const EnvProxyAllowInsecure = "NB_PROXY_ALLOW_INSECURE"
var _ credentials.PerRPCCredentials = (*proxyAuthToken)(nil)
// proxyAuthToken is a per-RPC credential that attaches the proxy token as a bearer token.
type proxyAuthToken struct {
// token is sent as "Bearer <token>" in the authorization metadata.
token string
// allowInsecure, when true, lets the token be sent without transport security.
allowInsecure bool
}
// GetRequestMetadata returns the metadata attached to each RPC: a single
// "authorization" entry carrying the token as a bearer credential. It never
// returns an error.
func (t proxyAuthToken) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
	metadata := make(map[string]string, 1)
	metadata["authorization"] = "Bearer " + t.token
	return metadata, nil
}
// RequireTransportSecurity reports whether the token may only travel over a
// secured transport. It is true unless insecure use was explicitly allowed
// through NB_PROXY_ALLOW_INSECURE=true (not recommended for production).
func (t proxyAuthToken) RequireTransportSecurity() bool {
	if t.allowInsecure {
		return false
	}
	return true
}
// WithProxyToken returns a DialOption that sets the proxy access token on each outbound RPC.
//
// The EnvProxyAllowInsecure environment variable is read once, at call time.
// An unset or empty variable keeps transport security required; a value that
// is not a valid boolean is logged as a warning and treated the same way.
func WithProxyToken(token string) grpc.DialOption {
	creds := proxyAuthToken{token: token}

	if raw := os.Getenv(EnvProxyAllowInsecure); raw != "" {
		allow, parseErr := strconv.ParseBool(raw)
		if parseErr == nil {
			creds.allowInsecure = allow
		} else {
			log.Warnf("invalid value for %s: %v", EnvProxyAllowInsecure, parseErr)
		}
	}

	return grpc.WithPerRPCCredentials(creds)
}

View File

@@ -0,0 +1,405 @@
// Package health provides health probes for the proxy server.
package health
import (
"context"
"encoding/json"
"net"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// NOTE(review): presumably the handshake age beyond which a peer counts as stale — its use is outside this section; confirm.
const handshakeStaleThreshold = 5 * time.Minute
const (
// maxConcurrentChecks is presumably the capacity of Checker.checkSem — TODO confirm where the semaphore is created.
maxConcurrentChecks = 3
// maxClientCheckTimeout is the upper bound CheckClientsConnected applies to its context.
maxClientCheckTimeout = 5 * time.Minute
)
// clientProvider provides access to NetBird clients for health checks.
type clientProvider interface {
// ListClientsForStartup returns the clients to check, keyed by account ID.
ListClientsForStartup() map[types.AccountID]*embed.Client
}
// Checker tracks health state and provides probe endpoints.
type Checker struct {
logger *log.Logger
provider clientProvider
// mu guards managementConnected, initialSyncComplete and shuttingDown.
mu sync.RWMutex
managementConnected bool
initialSyncComplete bool
shuttingDown bool
// checkSem limits concurrent client health checks.
checkSem chan struct{}
// checkHealth checks the health of a single client.
// Defaults to checkClientHealth; overridable in tests.
checkHealth func(*embed.Client) ClientHealth
}
// ClientHealth represents the health status of a single NetBird client.
type ClientHealth struct {
// Healthy is the flag CheckClientsConnected aggregates into its overall result.
Healthy bool `json:"healthy"`
ManagementConnected bool `json:"management_connected"`
SignalConnected bool `json:"signal_connected"`
RelaysConnected int `json:"relays_connected"`
RelaysTotal int `json:"relays_total"`
PeersTotal int `json:"peers_total"`
PeersConnected int `json:"peers_connected"`
PeersP2P int `json:"peers_p2p"`
PeersRelayed int `json:"peers_relayed"`
PeersDegraded int `json:"peers_degraded"`
// Error is omitted from JSON when empty. CheckClientsConnected sets it to
// the context error when a check could not start before the context ended.
Error string `json:"error,omitempty"`
}
// ProbeResponse represents the JSON response for health probes.
// Checks and Clients are omitted from the JSON output when empty.
type ProbeResponse struct {
Status string `json:"status"`
Checks map[string]bool `json:"checks,omitempty"`
Clients map[types.AccountID]ClientHealth `json:"clients,omitempty"`
}
// Server runs the health probe HTTP server on a dedicated port.
// It bundles the HTTP server with its logger and the Checker whose state it serves.
type Server struct {
server *http.Server
logger *log.Logger
checker *Checker
}
// SetManagementConnected records, under the write lock, whether the
// management connection is currently up.
func (c *Checker) SetManagementConnected(connected bool) {
	c.mu.Lock()
	c.managementConnected = connected
	c.mu.Unlock()
}
// SetInitialSyncComplete records, under the write lock, that the initial
// mapping sync has finished. There is no way to clear the flag again.
func (c *Checker) SetInitialSyncComplete() {
	c.mu.Lock()
	c.initialSyncComplete = true
	c.mu.Unlock()
}
// SetShuttingDown flags the server as shutting down.
// From then on ReadinessProbe returns false so load balancers stop routing traffic.
func (c *Checker) SetShuttingDown() {
	c.mu.Lock()
	c.shuttingDown = true
	c.mu.Unlock()
}
// CheckClientsConnected verifies all clients are connected to management/signal/relay.
// Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout.
// Limits concurrent checks via semaphore.
//
// It returns whether every client reported healthy, together with the
// per-account results. With no clients it returns true and an empty map.
func (c *Checker) CheckClientsConnected(ctx context.Context) (bool, map[types.AccountID]ClientHealth) {
// Apply upper bound timeout in case parent context has no deadline
ctx, cancel := context.WithTimeout(ctx, maxClientCheckTimeout)
defer cancel()
clients := c.provider.ListClientsForStartup()
// No clients is not a health issue
if len(clients) == 0 {
return true, make(map[types.AccountID]ClientHealth)
}
type result struct {
accountID types.AccountID
health ClientHealth
}
// Buffered to one slot per client so no worker blocks on send.
resultsCh := make(chan result, len(clients))
var wg sync.WaitGroup
for accountID, client := range clients {
wg.Add(1)
go func(id types.AccountID, cl *embed.Client) {
defer wg.Done()
// Acquire semaphore
select {
case c.checkSem <- struct{}{}:
defer func() { <-c.checkSem }()
case <-ctx.Done():
// The context ended before a slot was free: report unhealthy with the context error.
resultsCh <- result{id, ClientHealth{Healthy: false, Error: ctx.Err().Error()}}
return
}
// checkHealth receives no context, so the timeout only applies while waiting for a slot.
resultsCh <- result{id, c.checkHealth(cl)}
}(accountID, client)
}
// Close the channel once every worker has reported so the loop below ends.
go func() {
wg.Wait()
close(resultsCh)
}()
results := make(map[types.AccountID]ClientHealth)
allHealthy := true
for r := range resultsCh {
results[r.accountID] = r.health
if !r.health.Healthy {
allHealthy = false
}
}
return allHealthy, results
}
// LivenessProbe reports whether the process is alive.
// Being able to answer at all is the proof of life, so the result is
// unconditionally true.
func (c *Checker) LivenessProbe() bool {
	const alive = true
	return alive
}
// ReadinessProbe reports whether the server can accept traffic: it must
// not be shutting down and the management connection must be up.
func (c *Checker) ReadinessProbe() bool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	return !c.shuttingDown && c.managementConnected
}
// StartupProbe reports whether initial startup has finished.
// It requires the management connection and the initial sync, and then
// checks the health of every client directly. The supplied context bounds
// the client checks.
func (c *Checker) StartupProbe(ctx context.Context) bool {
	c.mu.RLock()
	mgmtConnected, syncComplete := c.managementConnected, c.initialSyncComplete
	c.mu.RUnlock()

	if !(mgmtConnected && syncComplete) {
		return false
	}

	// Every client must be connected to management/signal/relay.
	// With no clients at all there is nothing to check and the result is true.
	clientsHealthy, _ := c.CheckClientsConnected(ctx)
	return clientsHealthy
}
// Handler builds the http.Handler that serves the health probe endpoints:
// /healthz/live, /healthz/ready, /healthz/startup and /healthz.
func (c *Checker) Handler() http.Handler {
	routes := []struct {
		pattern string
		handle  func(http.ResponseWriter, *http.Request)
	}{
		{"/healthz/live", c.handleLiveness},
		{"/healthz/ready", c.handleReadiness},
		{"/healthz/startup", c.handleStartup},
		{"/healthz", c.handleFull},
	}

	mux := http.NewServeMux()
	for _, rt := range routes {
		mux.HandleFunc(rt.pattern, rt.handle)
	}
	return mux
}
// handleLiveness serves the liveness endpoint: 200/"ok" when LivenessProbe
// passes, 503/"fail" otherwise. No checks or client details are included.
func (c *Checker) handleLiveness(w http.ResponseWriter, r *http.Request) {
	code, status := http.StatusServiceUnavailable, "fail"
	if c.LivenessProbe() {
		code, status = http.StatusOK, "ok"
	}
	c.writeProbeResponse(w, code, status, nil, nil)
}
// handleReadiness serves the readiness endpoint. It answers 200/"ok" when
// the management connection is up and the server is not shutting down, and
// 503/"fail" otherwise. The "management_connected" check is always included.
func (c *Checker) handleReadiness(w http.ResponseWriter, r *http.Request) {
	// Read both flags under one lock acquisition so the reported check and
	// the overall status describe the same moment. Reading the check and
	// then calling ReadinessProbe separately lets the state change in
	// between, producing a response that contradicts itself.
	c.mu.RLock()
	mgmtConnected := c.managementConnected
	shuttingDown := c.shuttingDown
	c.mu.RUnlock()

	checks := map[string]bool{
		"management_connected": mgmtConnected,
	}

	// Same condition as ReadinessProbe, evaluated on the snapshot above.
	if mgmtConnected && !shuttingDown {
		c.writeProbeResponse(w, http.StatusOK, "ok", checks, nil)
		return
	}
	c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, nil)
}
// handleStartup serves the startup endpoint. It reports the management,
// initial-sync and client-health checks and answers 200/"ok" only when all
// three hold; otherwise 503/"fail". Per-client details are always attached.
func (c *Checker) handleStartup(w http.ResponseWriter, r *http.Request) {
	c.mu.RLock()
	mgmtConnected, synced := c.managementConnected, c.initialSyncComplete
	c.mu.RUnlock()

	// Client checks run regardless of the flags above so the response
	// always carries the full picture.
	clientsOK, perClient := c.CheckClientsConnected(r.Context())

	checks := map[string]bool{
		"management_connected":  mgmtConnected,
		"initial_sync_complete": synced,
		"all_clients_healthy":   clientsOK,
	}

	code, status := http.StatusServiceUnavailable, "fail"
	if mgmtConnected && synced && clientsOK {
		code, status = http.StatusOK, "ok"
	}
	c.writeProbeResponse(w, code, status, checks, perClient)
}
// handleFull serves the combined /healthz endpoint. The body lists the
// management, initial-sync and client-health checks plus per-client
// details, while the HTTP status and "ok"/"fail" follow ReadinessProbe only.
func (c *Checker) handleFull(w http.ResponseWriter, r *http.Request) {
	c.mu.RLock()
	mgmtConnected, synced := c.managementConnected, c.initialSyncComplete
	c.mu.RUnlock()

	clientsOK, perClient := c.CheckClientsConnected(r.Context())

	checks := map[string]bool{
		"management_connected":  mgmtConnected,
		"initial_sync_complete": synced,
		"all_clients_healthy":   clientsOK,
	}

	// Readiness alone decides the status; the checks are informational.
	code, status := http.StatusOK, "ok"
	if ready := c.ReadinessProbe(); !ready {
		code, status = http.StatusServiceUnavailable, "fail"
	}

	c.writeProbeResponse(w, code, status, checks, perClient)
}
// writeProbeResponse sends a ProbeResponse as JSON with the given HTTP
// status code. An encoding/write failure is only logged at debug level
// because the status line has already been sent by then.
func (c *Checker) writeProbeResponse(w http.ResponseWriter, statusCode int, status string, checks map[string]bool, clients map[types.AccountID]ClientHealth) {
	payload := ProbeResponse{Status: status, Checks: checks, Clients: clients}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)

	err := json.NewEncoder(w).Encode(payload)
	if err == nil {
		return
	}
	c.logger.Debugf("write health response: %v", err)
}
// ListenAndServe logs the configured address and then runs the health
// probe server on it, returning whatever the underlying HTTP server returns.
func (s *Server) ListenAndServe() error {
	httpServer := s.server
	s.logger.Infof("starting health probe server on %s", httpServer.Addr)
	return httpServer.ListenAndServe()
}
// Serve logs the listener address and then runs the health probe server on
// the supplied listener, returning whatever the underlying HTTP server returns.
func (s *Server) Serve(l net.Listener) error {
	listenAddr := l.Addr()
	s.logger.Infof("starting health probe server on %s", listenAddr)
	return s.server.Serve(l)
}
// Shutdown stops the health probe server gracefully by delegating to the
// underlying HTTP server's Shutdown with the given context.
func (s *Server) Shutdown(ctx context.Context) error {
	err := s.server.Shutdown(ctx)
	return err
}
// NewChecker builds a health Checker backed by the given client provider.
// A nil logger falls back to the logrus standard logger. The checker starts
// with a semaphore of maxConcurrentChecks slots and uses checkClientHealth
// to inspect individual clients.
func NewChecker(logger *log.Logger, provider clientProvider) *Checker {
	if logger == nil {
		logger = log.StandardLogger()
	}

	checker := &Checker{
		logger:      logger,
		provider:    provider,
		checkSem:    make(chan struct{}, maxConcurrentChecks),
		checkHealth: checkClientHealth,
	}
	return checker
}
// NewServer builds a health probe server listening on addr.
// When metricsHandler is non-nil it is mounted at /metrics on the same
// port, with every other path routed to the checker's probe handler.
// A nil logger falls back to the logrus standard logger.
func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler http.Handler) *Server {
	if logger == nil {
		logger = log.StandardLogger()
	}

	root := checker.Handler()
	if metricsHandler != nil {
		combined := http.NewServeMux()
		combined.Handle("/metrics", metricsHandler)
		combined.Handle("/", root)
		root = combined
	}

	httpServer := &http.Server{
		Addr:         addr,
		Handler:      root,
		ReadTimeout:  5 * time.Second,
		WriteTimeout: 5 * time.Second,
	}

	return &Server{
		server:  httpServer,
		logger:  logger,
		checker: checker,
	}
}
// checkClientHealth inspects one embedded client and summarises its
// connectivity. A nil client or a failing Status call yields an unhealthy
// result carrying the reason in Error. Otherwise the client counts as
// healthy when management and signal are connected and, if any rel:// or
// rels:// relays are configured, at least one of them has no error.
// Peer counters are informational and do not affect Healthy.
func checkClientHealth(client *embed.Client) ClientHealth {
	if client == nil {
		return ClientHealth{Healthy: false, Error: "client not initialized"}
	}

	status, err := client.Status()
	if err != nil {
		return ClientHealth{Healthy: false, Error: err.Error()}
	}

	// Only rel:// and rels:// entries are relays; stun/turn entries are skipped.
	var totalRelays, upRelays int
	for _, relay := range status.Relays {
		isRelay := strings.HasPrefix(relay.URI, "rel://") || strings.HasPrefix(relay.URI, "rels://")
		if !isRelay {
			continue
		}
		totalRelays++
		if relay.Err == nil {
			upRelays++
		}
	}

	health := ClientHealth{
		ManagementConnected: status.ManagementState.Connected,
		SignalConnected:     status.SignalState.Connected,
		RelaysConnected:     upRelays,
		RelaysTotal:         totalRelays,
		PeersTotal:          len(status.Peers),
	}

	// Tally connected peers by transport and flag stale handshakes.
	checkedAt := time.Now()
	for _, peer := range status.Peers {
		if peer.ConnStatus != embed.PeerStatusConnected {
			continue
		}
		health.PeersConnected++

		if peer.Relayed {
			health.PeersRelayed++
		} else {
			health.PeersP2P++
		}

		stale := peer.LastWireguardHandshake.IsZero() ||
			checkedAt.Sub(peer.LastWireguardHandshake) > handshakeStaleThreshold
		if stale {
			health.PeersDegraded++
		}
	}

	// Relays only matter when at least one is defined.
	relaysOK := totalRelays == 0 || upRelays > 0
	health.Healthy = health.ManagementConnected && health.SignalConnected && relaysOK

	return health
}

View File

@@ -0,0 +1,473 @@
package health
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// mockClientProvider is a test double for the checker's client provider.
// It hands back a fixed map of clients.
type mockClientProvider struct {
// clients is returned as-is by ListClientsForStartup; nil means "no clients".
clients map[types.AccountID]*embed.Client
}
// ListClientsForStartup returns the preconfigured client map unchanged.
func (m *mockClientProvider) ListClientsForStartup() map[types.AccountID]*embed.Client {
	configured := m.clients
	return configured
}
// newTestChecker builds a checker whose per-client health function is
// stubbed out: every client, whatever it is, reports healthResult.
func newTestChecker(provider clientProvider, healthResult ClientHealth) *Checker {
	stub := func(_ *embed.Client) ClientHealth {
		return healthResult
	}

	checker := NewChecker(nil, provider)
	checker.checkHealth = stub
	return checker
}
// TestChecker_LivenessProbe checks that a fresh checker reports alive.
func TestChecker_LivenessProbe(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})

	// Being able to answer is all liveness requires.
	alive := hc.LivenessProbe()
	assert.True(t, alive)
}
// TestChecker_ReadinessProbe checks that readiness tracks the management
// connection state as it goes up and down.
func TestChecker_ReadinessProbe(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})

	// A fresh checker has no management connection, so it is not ready.
	assert.False(t, hc.ReadinessProbe())

	// Connecting to management makes it ready.
	hc.SetManagementConnected(true)
	assert.True(t, hc.ReadinessProbe())

	// Losing the connection makes it not ready again.
	hc.SetManagementConnected(false)
	assert.False(t, hc.ReadinessProbe())
}
// TestStartupProbe_EmptyServiceList covers a proxy for which management has
// no services configured. Startup must pass as soon as management is
// connected and the initial sync is done, even though there are no clients.
func TestStartupProbe_EmptyServiceList(t *testing.T) {
	ctx := context.Background()
	hc := NewChecker(nil, &mockClientProvider{})

	// Nothing set yet: not started.
	assert.False(t, hc.StartupProbe(ctx))

	// Management alone is not enough; the sync is still pending.
	hc.SetManagementConnected(true)
	assert.False(t, hc.StartupProbe(ctx))

	// Management + sync with zero clients: started.
	hc.SetInitialSyncComplete()
	assert.True(t, hc.StartupProbe(ctx))
}
// TestStartupProbe_WithUnhealthyClients verifies that when services exist
// and clients have been created but are not yet fully connected (to mgmt,
// signal, relays), the startup probe does NOT pass.
func TestStartupProbe_WithUnhealthyClients(t *testing.T) {
provider := &mockClientProvider{
clients: map[types.AccountID]*embed.Client{
"account-1": nil, // concrete client not needed; checkHealth is mocked
"account-2": nil,
},
}
checker := newTestChecker(provider, ClientHealth{Healthy: false, Error: "not connected yet"})
checker.SetManagementConnected(true)
checker.SetInitialSyncComplete()
assert.False(t, checker.StartupProbe(context.Background()),
"startup probe must not pass when clients are unhealthy")
}
// TestStartupProbe_WithHealthyClients verifies that once all clients are
// connected and healthy, the startup probe passes.
func TestStartupProbe_WithHealthyClients(t *testing.T) {
provider := &mockClientProvider{
clients: map[types.AccountID]*embed.Client{
"account-1": nil,
"account-2": nil,
},
}
checker := newTestChecker(provider, ClientHealth{
Healthy: true,
ManagementConnected: true,
SignalConnected: true,
RelaysConnected: 1,
RelaysTotal: 1,
})
checker.SetManagementConnected(true)
checker.SetInitialSyncComplete()
assert.True(t, checker.StartupProbe(context.Background()),
"startup probe must pass when all clients are healthy")
}
// TestStartupProbe_MixedHealthClients verifies that if any single client is
// unhealthy, the startup probe fails (all-or-nothing).
func TestStartupProbe_MixedHealthClients(t *testing.T) {
	provider := &mockClientProvider{
		clients: map[types.AccountID]*embed.Client{
			"healthy-account":   nil,
			"unhealthy-account": nil,
		},
	}

	checker := NewChecker(nil, provider)

	// The stub cannot tell accounts apart (it only receives the nil client)
	// and map iteration order is random, so health is assigned by call
	// order: the first client checked is healthy, every later one is not.
	// With two clients this yields exactly one healthy and one unhealthy.
	// The checker invokes checkHealth from one goroutine per client, hence
	// the mutex around the counter.
	var (
		mu    sync.Mutex
		calls int
	)
	checker.checkHealth = func(_ *embed.Client) ClientHealth {
		mu.Lock()
		defer mu.Unlock()
		calls++
		return ClientHealth{Healthy: calls == 1}
	}

	checker.SetManagementConnected(true)
	checker.SetInitialSyncComplete()

	assert.False(t, checker.StartupProbe(context.Background()),
		"startup probe must fail if any client is unhealthy")
}
// TestStartupProbe_RequiresAllConditions checks that management, sync and
// client health are each necessary: leaving any one out must fail the
// probe, and only the full combination passes.
func TestStartupProbe_RequiresAllConditions(t *testing.T) {
	provider := &mockClientProvider{
		clients: map[types.AccountID]*embed.Client{
			"account-1": nil,
		},
	}

	cases := []struct {
		name          string
		mgmt          bool
		synced        bool
		clientHealthy bool
		want          bool
	}{
		{name: "no management", mgmt: false, synced: true, clientHealthy: true, want: false},
		{name: "no sync", mgmt: true, synced: false, clientHealthy: true, want: false},
		{name: "unhealthy client", mgmt: true, synced: true, clientHealthy: false, want: false},
		{name: "all conditions met", mgmt: true, synced: true, clientHealthy: true, want: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			hc := newTestChecker(provider, ClientHealth{Healthy: tc.clientHealthy})
			// Only flip the flags the case asks for; the rest stay at their zero state.
			if tc.mgmt {
				hc.SetManagementConnected(true)
			}
			if tc.synced {
				hc.SetInitialSyncComplete()
			}

			got := hc.StartupProbe(context.Background())
			if tc.want {
				assert.True(t, got)
			} else {
				assert.False(t, got)
			}
		})
	}
}
// TestStartupProbe_ConcurrentAccess runs the startup probe from many
// goroutines simultaneously to check for races.
func TestStartupProbe_ConcurrentAccess(t *testing.T) {
provider := &mockClientProvider{
clients: map[types.AccountID]*embed.Client{
"account-1": nil,
"account-2": nil,
},
}
checker := newTestChecker(provider, ClientHealth{Healthy: true})
checker.SetManagementConnected(true)
checker.SetInitialSyncComplete()
var wg sync.WaitGroup
const goroutines = 50
results := make([]bool, goroutines)
for i := range goroutines {
wg.Add(1)
go func(idx int) {
defer wg.Done()
results[idx] = checker.StartupProbe(context.Background())
}(i)
}
wg.Wait()
for i, r := range results {
assert.True(t, r, "goroutine %d got unexpected result", i)
}
}
// TestStartupProbe_CancelledContext checks the probe's answer when it is
// handed an already-cancelled context.
func TestStartupProbe_CancelledContext(t *testing.T) {
	// cancelled returns a context that is already done.
	cancelled := func() context.Context {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		return ctx
	}

	t.Run("no management bypasses context", func(t *testing.T) {
		hc := NewChecker(nil, &mockClientProvider{})

		// Management is not connected, so the probe fails before the
		// context is ever consulted.
		assert.False(t, hc.StartupProbe(cancelled()))
	})

	t.Run("with clients and cancelled context", func(t *testing.T) {
		accounts := map[types.AccountID]*embed.Client{
			"account-1": nil,
		}
		// The real health function is kept on purpose: the lone client is
		// nil, so the probe is unhealthy whether or not the check runs.
		hc := NewChecker(nil, &mockClientProvider{clients: accounts})
		hc.SetManagementConnected(true)
		hc.SetInitialSyncComplete()

		assert.False(t, hc.StartupProbe(cancelled()),
			"cancelled context must result in unhealthy when clients exist")
	})
}
// TestHandler_Startup_EmptyServiceList checks that GET /healthz/startup
// answers 200 when management is connected, the sync is complete and no
// services/clients exist.
func TestHandler_Startup_EmptyServiceList(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)
	hc.SetInitialSyncComplete()

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/startup", nil))

	assert.Equal(t, http.StatusOK, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "ok", body.Status)
	assert.True(t, body.Checks["management_connected"])
	assert.True(t, body.Checks["initial_sync_complete"])
	assert.True(t, body.Checks["all_clients_healthy"])
	assert.Empty(t, body.Clients)
}
// TestHandler_Startup_WithUnhealthyClients verifies that the HTTP startup
// endpoint returns 503 when clients exist but are not yet healthy.
func TestHandler_Startup_WithUnhealthyClients(t *testing.T) {
provider := &mockClientProvider{
clients: map[types.AccountID]*embed.Client{
"account-1": nil,
},
}
checker := newTestChecker(provider, ClientHealth{Healthy: false, Error: "starting"})
checker.SetManagementConnected(true)
checker.SetInitialSyncComplete()
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "fail", resp.Status)
assert.True(t, resp.Checks["management_connected"])
assert.True(t, resp.Checks["initial_sync_complete"])
assert.False(t, resp.Checks["all_clients_healthy"])
require.Contains(t, resp.Clients, types.AccountID("account-1"))
assert.Equal(t, "starting", resp.Clients["account-1"].Error)
}
// TestHandler_Startup_WithHealthyClients verifies the HTTP startup endpoint
// returns 200 once clients are healthy.
func TestHandler_Startup_WithHealthyClients(t *testing.T) {
provider := &mockClientProvider{
clients: map[types.AccountID]*embed.Client{
"account-1": nil,
},
}
checker := newTestChecker(provider, ClientHealth{
Healthy: true,
ManagementConnected: true,
SignalConnected: true,
RelaysConnected: 1,
RelaysTotal: 1,
})
checker.SetManagementConnected(true)
checker.SetInitialSyncComplete()
handler := checker.Handler()
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var resp ProbeResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
assert.Equal(t, "ok", resp.Status)
assert.True(t, resp.Checks["all_clients_healthy"])
}
// TestHandler_Startup_NotComplete checks that GET /healthz/startup answers
// 503 on a fresh checker whose prerequisites have not been met.
func TestHandler_Startup_NotComplete(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/startup", nil))

	assert.Equal(t, http.StatusServiceUnavailable, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "fail", body.Status)
}
// TestChecker_Handler_Liveness checks that GET /healthz/live answers
// 200/"ok" on a fresh checker.
func TestChecker_Handler_Liveness(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/live", nil))

	assert.Equal(t, http.StatusOK, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "ok", body.Status)
}
// TestChecker_Handler_Readiness_NotReady checks that GET /healthz/ready
// answers 503/"fail" while management is not connected.
func TestChecker_Handler_Readiness_NotReady(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/ready", nil))

	assert.Equal(t, http.StatusServiceUnavailable, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "fail", body.Status)
	assert.False(t, body.Checks["management_connected"])
}
// TestChecker_Handler_Readiness_Ready checks that GET /healthz/ready
// answers 200/"ok" once management is connected.
func TestChecker_Handler_Readiness_Ready(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/ready", nil))

	assert.Equal(t, http.StatusOK, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "ok", body.Status)
	assert.True(t, body.Checks["management_connected"])
}
// TestChecker_Handler_Full checks that GET /healthz answers 200/"ok" with
// a checks section when management is connected and no clients exist.
func TestChecker_Handler_Full(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz", nil))

	assert.Equal(t, http.StatusOK, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "ok", body.Status)
	assert.NotNil(t, body.Checks)
	// With no clients the clients section is empty (or absent).
	assert.Empty(t, body.Clients)
}
// TestChecker_SetShuttingDown checks that flagging shutdown turns a ready
// checker into a not-ready one.
func TestChecker_SetShuttingDown(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)

	readyBefore := hc.ReadinessProbe()
	assert.True(t, readyBefore, "should be ready before shutdown")

	hc.SetShuttingDown()

	readyAfter := hc.ReadinessProbe()
	assert.False(t, readyAfter, "should not be ready after shutdown")
}
// TestChecker_Handler_Readiness_ShuttingDown checks that GET /healthz/ready
// answers 503/"fail" during shutdown even though management is connected.
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)
	hc.SetShuttingDown()

	rr := httptest.NewRecorder()
	hc.Handler().ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/ready", nil))

	assert.Equal(t, http.StatusServiceUnavailable, rr.Code)

	var body ProbeResponse
	require.NoError(t, json.NewDecoder(rr.Body).Decode(&body))
	assert.Equal(t, "fail", body.Status)
}
// TestNewServer_WithMetricsHandler checks that a server built with a
// metrics handler serves both the probe endpoints and /metrics.
func TestNewServer_WithMetricsHandler(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)

	var metrics http.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("metrics"))
	})

	srv := NewServer(":0", hc, nil, metrics)
	require.NotNil(t, srv)
	root := srv.server.Handler

	// The probe endpoints must still be reachable through the combined mux.
	liveRec := httptest.NewRecorder()
	root.ServeHTTP(liveRec, httptest.NewRequest(http.MethodGet, "/healthz/live", nil))
	assert.Equal(t, http.StatusOK, liveRec.Code)

	// The metrics handler must be mounted at /metrics.
	metricsRec := httptest.NewRecorder()
	root.ServeHTTP(metricsRec, httptest.NewRequest(http.MethodGet, "/metrics", nil))
	assert.Equal(t, http.StatusOK, metricsRec.Code)
	assert.Equal(t, "metrics", metricsRec.Body.String())
}
// TestNewServer_WithoutMetricsHandler checks that a server built without a
// metrics handler still serves the probe endpoints.
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
	hc := NewChecker(nil, &mockClientProvider{})
	hc.SetManagementConnected(true)

	srv := NewServer(":0", hc, nil, nil)
	require.NotNil(t, srv)

	rr := httptest.NewRecorder()
	srv.server.Handler.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/healthz/live", nil))
	assert.Equal(t, http.StatusOK, rr.Code)
}

281
proxy/internal/k8s/lease.go Normal file
View File

@@ -0,0 +1,281 @@
// Package k8s provides a lightweight Kubernetes API client for coordination
// Leases. It uses raw HTTP calls against the mounted service account
// credentials, avoiding a dependency on client-go.
package k8s
import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"strings"
	"time"
)
const (
// saTokenPath is the mounted service account bearer token; readToken reads it.
saTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" //nolint:gosec
// saNamespacePath is the file NewLeaseClient reads the namespace from.
saNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
// saCACertPath is the PEM CA bundle NewLeaseClient trusts for the API server.
saCACertPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
// leaseAPIPath is the URL path prefix of the coordination.k8s.io/v1 API group.
leaseAPIPath = "/apis/coordination.k8s.io/v1"
)
// ErrConflict is returned when a Lease update fails due to a
// resourceVersion mismatch (another writer updated the object first).
// Create also returns it when the API server answers 409 Conflict.
var ErrConflict = errors.New("conflict: resource version mismatch")
// Lease represents a coordination.k8s.io/v1 Lease object with only the
// fields needed for distributed locking.
type Lease struct {
// APIVersion is overwritten with "coordination.k8s.io/v1" by Create and Update.
APIVersion string `json:"apiVersion"`
// Kind is overwritten with "Lease" by Create and Update.
Kind string `json:"kind"`
// Metadata carries the object name, namespace, resourceVersion and annotations.
Metadata LeaseMetadata `json:"metadata"`
// Spec carries the holder and timing fields of the lease.
Spec LeaseSpec `json:"spec"`
}
// LeaseMetadata holds the standard k8s object metadata fields used by Leases.
type LeaseMetadata struct {
// Name is the Lease name; Update uses it to build the request URL.
Name string `json:"name"`
// Namespace is filled in with the client's namespace by Create when empty.
Namespace string `json:"namespace,omitempty"`
// ResourceVersion is the optimistic-concurrency token; Update must send the current server value.
ResourceVersion string `json:"resourceVersion,omitempty"`
// Annotations are free-form key/value pairs stored on the Lease.
Annotations map[string]string `json:"annotations,omitempty"`
}
// LeaseSpec holds the Lease specification fields.
// HolderIdentity, AcquireTime and RenewTime have no omitempty, so a nil
// pointer is encoded as an explicit JSON null.
type LeaseSpec struct {
// HolderIdentity names the current holder; nil encodes as null.
HolderIdentity *string `json:"holderIdentity"`
// LeaseDurationSeconds is the lease duration in seconds; omitted from JSON when nil.
LeaseDurationSeconds *int32 `json:"leaseDurationSeconds,omitempty"`
// AcquireTime is the acquisition timestamp; nil encodes as null.
AcquireTime *MicroTime `json:"acquireTime"`
// RenewTime is the last renewal timestamp; nil encodes as null.
RenewTime *MicroTime `json:"renewTime"`
}
// MicroTime wraps time.Time with Kubernetes MicroTime JSON formatting.
type MicroTime struct {
	time.Time
}

// microTimeFormat is the layout MarshalJSON writes: microsecond precision
// followed by a literal "Z" (the value is converted to UTC before formatting).
const microTimeFormat = "2006-01-02T15:04:05.000000Z"

// MarshalJSON implements json.Marshaler with k8s MicroTime format.
//
// The receiver is a value so that both MicroTime and *MicroTime are encoded
// with this method. With a pointer receiver, a non-addressable MicroTime
// value would bypass it and lose the MicroTime format.
func (t MicroTime) MarshalJSON() ([]byte, error) {
	return json.Marshal(t.UTC().Format(microTimeFormat))
}

// UnmarshalJSON implements json.Unmarshaler with k8s MicroTime format.
//
// JSON null and the empty string both reset the value to the zero time.
// The canonical microsecond/"Z" form is tried first; any other RFC 3339
// timestamp (numeric zone offset, different fractional precision) is
// accepted as a fallback. An error is returned when the input is not a JSON
// string or cannot be parsed as a timestamp.
func (t *MicroTime) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	if s == "" {
		t.Time = time.Time{}
		return nil
	}

	parsed, err := time.Parse(microTimeFormat, s)
	if err != nil {
		// microTimeFormat only matches a literal "Z" suffix with exactly six
		// fractional digits; fall back to general RFC 3339 before giving up.
		fallback, fallbackErr := time.Parse(time.RFC3339Nano, s)
		if fallbackErr != nil {
			return fmt.Errorf("parse MicroTime %q: %w", s, err)
		}
		parsed = fallback
	}
	t.Time = parsed
	return nil
}
// LeaseClient talks to the Kubernetes coordination API using raw HTTP.
type LeaseClient struct {
// baseURL is the https scheme plus API server host and port, without a trailing path.
baseURL string
// namespace is the namespace every Lease request is scoped to.
namespace string
// httpClient performs the requests; NewLeaseClient configures its timeout and trusted CA pool.
httpClient *http.Client
}
// NewLeaseClient creates a client that authenticates via the pod's
// mounted service account. It reads the namespace and CA certificate
// at construction time (they don't rotate) but reads the bearer token
// fresh on each request (tokens rotate).
//
// The API server address comes from KUBERNETES_SERVICE_HOST and
// KUBERNETES_SERVICE_PORT. An error is returned when either variable is
// empty, when the namespace or CA file cannot be read, or when the CA file
// contains no usable PEM certificate.
func NewLeaseClient() (*LeaseClient, error) {
	host := os.Getenv("KUBERNETES_SERVICE_HOST")
	port := os.Getenv("KUBERNETES_SERVICE_PORT")
	if host == "" || port == "" {
		return nil, fmt.Errorf("KUBERNETES_SERVICE_HOST/PORT not set")
	}

	ns, err := os.ReadFile(saNamespacePath)
	if err != nil {
		return nil, fmt.Errorf("read namespace from %s: %w", saNamespacePath, err)
	}

	caCert, err := os.ReadFile(saCACertPath)
	if err != nil {
		return nil, fmt.Errorf("read CA cert from %s: %w", saCACertPath, err)
	}

	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(caCert) {
		return nil, fmt.Errorf("parse CA certificate from %s", saCACertPath)
	}

	return &LeaseClient{
		// JoinHostPort brackets IPv6 literals ("[fd00::1]:443"); plain
		// "host:port" concatenation yields an invalid URL on IPv6 clusters.
		baseURL:   "https://" + net.JoinHostPort(host, port),
		namespace: strings.TrimSpace(string(ns)),
		httpClient: &http.Client{
			Timeout: 10 * time.Second,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					RootCAs: pool,
				},
			},
		},
	}, nil
}
// Namespace reports the namespace all of this client's requests target.
func (c *LeaseClient) Namespace() string {
	ns := c.namespace
	return ns
}
// Get fetches the Lease with the given name from the client's namespace.
// A missing Lease (HTTP 404) is not an error: the result is (nil, nil).
// Any other non-200 status is turned into an error by readError.
func (c *LeaseClient) Get(ctx context.Context, name string) (*Lease, error) {
	endpoint := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, name)

	resp, err := c.doRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	switch resp.StatusCode {
	case http.StatusOK:
		// Decoded below.
	case http.StatusNotFound:
		return nil, nil //nolint:nilnil
	default:
		return nil, c.readError(resp)
	}

	found := new(Lease)
	if err := json.NewDecoder(resp.Body).Decode(found); err != nil {
		return nil, fmt.Errorf("decode lease response: %w", err)
	}
	return found, nil
}
// Create posts a new Lease and returns the server's copy, which includes
// server-assigned fields such as resourceVersion.
//
// The passed lease is modified in place: APIVersion and Kind are set, and
// an empty namespace is filled with the client's namespace. HTTP 409 is
// reported as ErrConflict; any other status than 201 becomes a readError.
func (c *LeaseClient) Create(ctx context.Context, lease *Lease) (*Lease, error) {
	endpoint := fmt.Sprintf("%s%s/namespaces/%s/leases", c.baseURL, leaseAPIPath, c.namespace)

	lease.APIVersion, lease.Kind = "coordination.k8s.io/v1", "Lease"
	if lease.Metadata.Namespace == "" {
		lease.Metadata.Namespace = c.namespace
	}

	resp, err := c.doRequest(ctx, http.MethodPost, endpoint, lease)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	switch resp.StatusCode {
	case http.StatusCreated:
		// Decoded below.
	case http.StatusConflict:
		return nil, ErrConflict
	default:
		return nil, c.readError(resp)
	}

	stored := new(Lease)
	if err := json.NewDecoder(resp.Body).Decode(stored); err != nil {
		return nil, fmt.Errorf("decode created lease: %w", err)
	}
	return stored, nil
}
// Update replaces the Lease named lease.Metadata.Name via PUT and returns
// the server's copy. lease.Metadata.ResourceVersion must equal the current
// server value (optimistic concurrency); a mismatch surfaces as HTTP 409
// and is reported as ErrConflict.
//
// The passed lease is modified in place: APIVersion and Kind are set.
func (c *LeaseClient) Update(ctx context.Context, lease *Lease) (*Lease, error) {
	endpoint := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, lease.Metadata.Name)

	lease.APIVersion, lease.Kind = "coordination.k8s.io/v1", "Lease"

	resp, err := c.doRequest(ctx, http.MethodPut, endpoint, lease)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	switch resp.StatusCode {
	case http.StatusOK:
		// Decoded below.
	case http.StatusConflict:
		return nil, ErrConflict
	default:
		return nil, c.readError(resp)
	}

	stored := new(Lease)
	if err := json.NewDecoder(resp.Body).Decode(stored); err != nil {
		return nil, fmt.Errorf("decode updated lease: %w", err)
	}
	return stored, nil
}
// doRequest sends one authenticated request to the API server.
// The bearer token is read from disk on every call. A non-nil body is
// JSON-encoded and sent with a JSON Content-Type. On success the caller
// owns the response and must close its body.
func (c *LeaseClient) doRequest(ctx context.Context, method, url string, body any) (*http.Response, error) {
	token, err := readToken()
	if err != nil {
		return nil, fmt.Errorf("read service account token: %w", err)
	}

	hasBody := body != nil

	// Stays a nil interface when there is no body, which is what
	// http.NewRequestWithContext expects for body-less requests.
	var payload io.Reader
	if hasBody {
		encoded, marshalErr := json.Marshal(body)
		if marshalErr != nil {
			return nil, fmt.Errorf("marshal request body: %w", marshalErr)
		}
		payload = bytes.NewReader(encoded)
	}

	req, err := http.NewRequestWithContext(ctx, method, url, payload)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}

	req.Header.Set("Authorization", "Bearer "+token)
	req.Header.Set("Accept", "application/json")
	if hasBody {
		req.Header.Set("Content-Type", "application/json")
	}

	return c.httpClient.Do(req)
}
// readToken loads the service account bearer token from saTokenPath and
// strips surrounding whitespace. A read failure is wrapped with the path.
func readToken() (string, error) {
	raw, err := os.ReadFile(saTokenPath)
	if err != nil {
		return "", fmt.Errorf("read %s: %w", saTokenPath, err)
	}

	token := strings.TrimSpace(string(raw))
	return token, nil
}
// readError turns an unexpected API response into an error containing the
// request path, the status code and at most the first 1024 bytes of the
// body. A failure while reading the body is ignored; whatever was read is used.
func (c *LeaseClient) readError(resp *http.Response) error {
	const maxBody = 1024

	snippet, _ := io.ReadAll(io.LimitReader(resp.Body, maxBody))
	path := resp.Request.URL.Path
	return fmt.Errorf("k8s API %s %d: %s", path, resp.StatusCode, string(snippet))
}
// LeaseNameForDomain derives a deterministic, DNS-label-safe Lease name
// from a domain: "cert-lock-" followed by the first 8 bytes of the domain's
// SHA-256 digest in lowercase hex (16 characters). Hashing sidesteps dots
// and length limits in the raw domain.
func LeaseNameForDomain(domain string) string {
	const prefix = "cert-lock-"

	digest := sha256.Sum256([]byte(domain))
	short := digest[:8]
	return prefix + hex.EncodeToString(short)
}
// InCluster reports whether the process appears to run inside a Kubernetes
// pod. The only signal used is whether the KUBERNETES_SERVICE_HOST
// environment variable is set; its value, even if empty, is not inspected.
func InCluster() bool {
	if _, ok := os.LookupEnv("KUBERNETES_SERVICE_HOST"); ok {
		return true
	}
	return false
}

View File

@@ -0,0 +1,102 @@
package k8s
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestLeaseNameForDomain checks the shape, determinism and uniqueness of
// generated Lease names for a handful of sample domains.
func TestLeaseNameForDomain(t *testing.T) {
	domains := []string{
		"example.com",
		"app.example.com",
		"another.domain.io",
	}

	// Maps each generated name back to the domain that produced it.
	owner := make(map[string]string, len(domains))

	for _, domain := range domains {
		name := LeaseNameForDomain(domain)

		assert.True(t, len(name) <= 63, "must be valid DNS label length")
		assert.Regexp(t, `^cert-lock-[0-9a-f]{16}$`, name,
			"must match expected format for domain %q", domain)

		// Calling again with the same domain must give the same name.
		assert.Equal(t, name, LeaseNameForDomain(domain), "must be deterministic")

		// No two sample domains may share a name.
		if earlier, clash := owner[name]; clash {
			t.Errorf("collision: %q and %q both map to %s", earlier, domain, name)
		}
		owner[name] = domain
	}
}
// TestMicroTimeJSON checks the exact encoded form of a MicroTime and that
// decoding it yields the same instant.
func TestMicroTimeJSON(t *testing.T) {
	instant := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)

	encoded, err := json.Marshal(&MicroTime{Time: instant})
	require.NoError(t, err)
	assert.Equal(t, `"2024-06-15T10:30:00.000000Z"`, string(encoded))

	var back MicroTime
	require.NoError(t, json.Unmarshal(encoded, &back))
	assert.True(t, instant.Equal(back.Time), "round-trip should preserve time")
}
// TestMicroTimeNullJSON checks that nil time pointers in a LeaseSpec are
// written as explicit JSON nulls rather than being omitted.
func TestMicroTimeNullJSON(t *testing.T) {
	// The zero LeaseSpec has every pointer field nil.
	var spec LeaseSpec

	encoded, err := json.Marshal(spec)
	require.NoError(t, err)

	text := string(encoded)
	assert.Contains(t, text, `"acquireTime":null`)
	assert.Contains(t, text, `"renewTime":null`)
}
// TestLeaseJSONRoundTrip encodes a fully populated Lease and checks that
// decoding it restores the name, resourceVersion, holder, duration and
// acquire time.
func TestLeaseJSONRoundTrip(t *testing.T) {
	var (
		holder   = "pod-abc"
		duration = int32(300)
		// Truncated to the microsecond because that is all the wire format keeps.
		stamp = MicroTime{Time: time.Now().UTC().Truncate(time.Microsecond)}
	)

	meta := LeaseMetadata{
		Name:            "cert-lock-abcdef0123456789",
		Namespace:       "default",
		ResourceVersion: "12345",
		Annotations: map[string]string{
			"netbird.io/domain": "app.example.com",
		},
	}
	spec := LeaseSpec{
		HolderIdentity:       &holder,
		LeaseDurationSeconds: &duration,
		AcquireTime:          &stamp,
		RenewTime:            &stamp,
	}
	want := Lease{
		APIVersion: "coordination.k8s.io/v1",
		Kind:       "Lease",
		Metadata:   meta,
		Spec:       spec,
	}

	encoded, err := json.Marshal(want)
	require.NoError(t, err)

	var got Lease
	require.NoError(t, json.Unmarshal(encoded, &got))

	assert.Equal(t, want.Metadata.Name, got.Metadata.Name)
	assert.Equal(t, want.Metadata.ResourceVersion, got.Metadata.ResourceVersion)
	assert.Equal(t, *want.Spec.HolderIdentity, *got.Spec.HolderIdentity)
	assert.Equal(t, *want.Spec.LeaseDurationSeconds, *got.Spec.LeaseDurationSeconds)
	assert.True(t, want.Spec.AcquireTime.Equal(got.Spec.AcquireTime.Time))
}

View File

@@ -0,0 +1,149 @@
package metrics
import (
"net/http"
"strconv"
"time"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// Metrics holds the Prometheus collectors exposed by the netbird proxy.
type Metrics struct {
	// requestsTotal counts every request made to the proxy.
	requestsTotal prometheus.Counter
	// activeRequests tracks the requests currently in flight.
	activeRequests prometheus.Gauge
	// configuredDomains tracks the number of domains configured on the proxy.
	configuredDomains prometheus.Gauge
	// pathsPerDomain tracks the number of configured paths, labelled by domain.
	pathsPerDomain *prometheus.GaugeVec
	// requestDuration observes how long requests made to the proxy take.
	requestDuration *prometheus.HistogramVec
	// backendDuration observes the round-trip time of calls made through the
	// wrapped http.RoundTripper.
	backendDuration *prometheus.HistogramVec
}
// New builds the proxy's Prometheus collectors through promauto.With(reg) and
// returns them bundled in a Metrics value.
func New(reg prometheus.Registerer) *Metrics {
	factory := promauto.With(reg)

	// Both duration histograms use the same bucket layout and label set; a
	// fresh slice is built for each so they share no backing storage.
	durationHistogram := func(name, help string) *prometheus.HistogramVec {
		return factory.NewHistogramVec(
			prometheus.HistogramOpts{
				Name:    name,
				Help:    help,
				Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
			},
			[]string{"status", "size", "method", "host", "path"},
		)
	}

	// Collectors are created in the same order as the struct fields.
	metrics := &Metrics{}
	metrics.requestsTotal = factory.NewCounter(prometheus.CounterOpts{
		Name: "netbird_proxy_requests_total",
		Help: "Total number of requests made to the netbird proxy",
	})
	metrics.activeRequests = factory.NewGauge(prometheus.GaugeOpts{
		Name: "netbird_proxy_active_requests_count",
		Help: "Current in-flight requests handled by the netbird proxy",
	})
	metrics.configuredDomains = factory.NewGauge(prometheus.GaugeOpts{
		Name: "netbird_proxy_domains_count",
		Help: "Current number of domains configured on the netbird proxy",
	})
	metrics.pathsPerDomain = factory.NewGaugeVec(
		prometheus.GaugeOpts{
			Name: "netbird_proxy_paths_count",
			Help: "Current number of paths configured on the netbird proxy labelled by domain",
		},
		[]string{"domain"},
	)
	metrics.requestDuration = durationHistogram(
		"netbird_proxy_request_duration_seconds",
		"Duration of requests made to the netbird proxy",
	)
	metrics.backendDuration = durationHistogram(
		"netbird_proxy_backend_duration_seconds",
		"Duration of peer round trip time from the netbird proxy",
	)
	return metrics
}
// responseInterceptor wraps an http.ResponseWriter and records the status code
// passed to WriteHeader and the number of body bytes written, so the metrics
// middleware can label its observations after the handler returns.
type responseInterceptor struct {
	http.ResponseWriter
	// status is the last status code passed to WriteHeader.
	status int
	// size is the running total of bytes reported written by Write.
	size int
}

// Unwrap returns the wrapped http.ResponseWriter.
//
// Embedding the http.ResponseWriter interface only promotes Header, Write and
// WriteHeader, so optional capabilities of the underlying writer (flushing,
// hijacking, deadlines) are hidden by this wrapper. http.ResponseController
// follows Unwrap to find them again.
func (w *responseInterceptor) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// WriteHeader remembers statusCode and then forwards it to the wrapped writer.
func (w *responseInterceptor) WriteHeader(statusCode int) {
	w.status = statusCode
	w.ResponseWriter.WriteHeader(statusCode)
}
// Write forwards b to the wrapped writer and adds the number of bytes it
// reports as written to the running size, even when an error is returned.
func (w *responseInterceptor) Write(b []byte) (int, error) {
	written, err := w.ResponseWriter.Write(b)
	w.size = w.size + written
	return written, err
}
// Middleware wraps next so that every request is counted, tracked as in-flight
// while it is being served, and timed into the request-duration histogram.
//
// The histogram observation is labelled with the response status, the number
// of body bytes written, the request method, the Host header and the URL path.
func (m *Metrics) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		m.requestsTotal.Inc()

		m.activeRequests.Inc()
		// Dec, not Desc: Desc only returns the metric descriptor and left the
		// in-flight gauge growing forever. Deferred so the gauge is also
		// restored when the wrapped handler panics.
		defer m.activeRequests.Dec()

		// net/http answers with 200 when a handler never calls WriteHeader, so
		// start from that rather than reporting status "0" for such responses.
		interceptor := &responseInterceptor{ResponseWriter: w, status: http.StatusOK}

		start := time.Now()
		next.ServeHTTP(interceptor, r)
		duration := time.Since(start)

		m.requestDuration.With(prometheus.Labels{
			"status": strconv.Itoa(interceptor.status),
			"size":   strconv.Itoa(interceptor.size),
			"method": r.Method,
			"host":   r.Host,
			"path":   r.URL.Path,
		}).Observe(duration.Seconds())
	})
}
// roundTripperFunc adapts a plain function to the http.RoundTripper interface.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls the underlying function with req and returns its result.
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	res, err := f(req)
	return res, err
}
// RoundTripper wraps next and records how long each round trip takes in the
// backend-duration histogram, labelled by method, host, path, response status
// and response content length.
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
	return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
		// Read the request-derived labels before handing the request on.
		method, host := req.Method, req.Host
		path := "/" // placeholder when the request carries no URL
		if req.URL != nil {
			path = req.URL.Path
		}

		began := time.Now()
		res, err := next.RoundTrip(req)
		elapsed := time.Since(began)

		// Placeholders for when there is no response to inspect, e.g. on error.
		status, size := "0", "0"
		if res != nil {
			status = strconv.Itoa(res.StatusCode)
			size = strconv.Itoa(int(res.ContentLength))
		}

		m.backendDuration.With(prometheus.Labels{
			"method": method,
			"host":   host,
			"path":   path,
			"status": status,
			"size":   size,
		}).Observe(elapsed.Seconds())

		return res, err
	})
}
// AddMapping increments the configured-domains gauge and sets the path-count
// gauge for the mapping's host to the number of paths in the mapping.
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
	m.configuredDomains.Inc()

	pathCount := float64(len(mapping.Paths))
	m.pathsPerDomain.WithLabelValues(mapping.Host).Set(pathCount)
}
// RemoveMapping decrements the configured-domains gauge and resets the
// path-count gauge for the mapping's host to zero.
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
	m.configuredDomains.Dec()

	// The labelled series is kept and zeroed rather than deleted.
	m.pathsPerDomain.WithLabelValues(mapping.Host).Set(0)
}

View File

@@ -0,0 +1,67 @@
package metrics_test
import (
"net/http"
"net/url"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/netbirdio/netbird/proxy/internal/metrics"
"github.com/prometheus/client_golang/prometheus"
)
// testRoundTripper is an http.RoundTripper stub that ignores the request and
// always returns its preset response and error.
type testRoundTripper struct {
	response *http.Response
	err      error
}

// RoundTrip returns the canned response and error.
func (rt *testRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	res, err := rt.response, rt.err
	return res, err
}
// TestMetrics_RoundTripper checks that the metrics round tripper passes the
// wrapped transport's response and error through unchanged, including when the
// request has no URL or the transport returns no response.
func TestMetrics_RoundTripper(t *testing.T) {
	okResponse := http.Response{
		StatusCode: http.StatusOK,
		Body:       http.NoBody,
	}

	type testCase struct {
		roundTripper http.RoundTripper
		request      *http.Request
		response     *http.Response
		err          error
	}

	cases := map[string]testCase{
		"ok": {
			roundTripper: &testRoundTripper{response: &okResponse},
			request:      &http.Request{Method: "GET", URL: &url.URL{Path: "/foo"}},
			response:     &okResponse,
		},
		"nil url": {
			roundTripper: &testRoundTripper{response: &okResponse},
			request:      &http.Request{Method: "GET", URL: nil},
			response:     &okResponse,
		},
		"nil response": {
			roundTripper: &testRoundTripper{response: nil},
			request:      &http.Request{Method: "GET", URL: &url.URL{Path: "/foo"}},
		},
	}

	// One Metrics instance on a private registry is shared by all subtests.
	m := metrics.New(prometheus.NewRegistry())

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			gotRes, gotErr := m.RoundTripper(tc.roundTripper).RoundTrip(tc.request)
			if gotRes != nil && gotRes.Body != nil {
				defer gotRes.Body.Close()
			}

			if diff := cmp.Diff(tc.err, gotErr); diff != "" {
				t.Errorf("Incorrect error (-want +got):\n%s", diff)
			}
			if diff := cmp.Diff(tc.response, gotRes); diff != "" {
				t.Errorf("Incorrect response (-want +got):\n%s", diff)
			}
		})
	}
}

View File

@@ -0,0 +1,187 @@
package proxy
import (
"context"
"sync"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// requestContextKey is the type of the keys this package uses to store values
// in a request's context.Context.
type requestContextKey string

const (
	// serviceIdKey is the context key under which the service ID is stored.
	serviceIdKey requestContextKey = "serviceId"
	// accountIdKey is the context key under which the account ID is stored.
	accountIdKey requestContextKey = "accountId"
	// capturedDataKey is the context key under which the *CapturedData is stored.
	capturedDataKey requestContextKey = "capturedData"
)
// ResponseOrigin identifies which part of the proxy produced a response.
type ResponseOrigin int

const (
	// OriginBackend: the response was produced by the backend service.
	OriginBackend ResponseOrigin = iota
	// OriginNoRoute: no host or path mapping matched the request.
	OriginNoRoute
	// OriginProxyError: the proxy could not reach the backend.
	OriginProxyError
	// OriginAuth: the proxy intercepted the request for authentication.
	OriginAuth
)

// String returns the snake_case name of the origin. OriginBackend and any
// unrecognised value are both reported as "backend".
func (o ResponseOrigin) String() string {
	if o == OriginNoRoute {
		return "no_route"
	}
	if o == OriginProxyError {
		return "proxy_error"
	}
	if o == OriginAuth {
		return "auth"
	}
	return "backend"
}
// CapturedData is a mutable struct that allows downstream handlers
// to pass data back up the middleware chain.
//
// The Get*/Set* methods take mu around each field access. The fields are also
// exported, so reading or writing them directly bypasses that locking.
// NOTE(review): confirm whether direct field access is relied on anywhere.
type CapturedData struct {
	// mu guards the fields below when they are accessed through the methods.
	mu sync.RWMutex
	// RequestID has a getter but no setter in this file.
	RequestID string
	// ServiceId is the ID of the service the request was routed to.
	ServiceId string
	// AccountId is the account the matched service belongs to.
	AccountId types.AccountID
	// Origin records where the response was generated.
	Origin ResponseOrigin
	// ClientIP is the resolved client IP.
	ClientIP string
	// UserID is the authenticated user ID.
	UserID string
	// AuthMethod is the authentication method used.
	AuthMethod string
}
// GetRequestID returns the request ID, reading it under the read lock.
func (c *CapturedData) GetRequestID() string {
	c.mu.RLock()
	requestID := c.RequestID
	c.mu.RUnlock()
	return requestID
}
// SetServiceId stores the service ID while holding the write lock.
func (c *CapturedData) SetServiceId(serviceId string) {
	c.mu.Lock()
	c.ServiceId = serviceId
	c.mu.Unlock()
}
// GetServiceId returns the service ID, reading it under the read lock.
func (c *CapturedData) GetServiceId() string {
	c.mu.RLock()
	serviceID := c.ServiceId
	c.mu.RUnlock()
	return serviceID
}
// SetAccountId stores the account ID while holding the write lock.
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
	c.mu.Lock()
	c.AccountId = accountId
	c.mu.Unlock()
}
// GetAccountId returns the account ID, reading it under the read lock.
func (c *CapturedData) GetAccountId() types.AccountID {
	c.mu.RLock()
	accountID := c.AccountId
	c.mu.RUnlock()
	return accountID
}
// SetOrigin stores the response origin while holding the write lock.
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
	c.mu.Lock()
	c.Origin = origin
	c.mu.Unlock()
}
// GetOrigin returns the response origin, reading it under the read lock.
func (c *CapturedData) GetOrigin() ResponseOrigin {
	c.mu.RLock()
	origin := c.Origin
	c.mu.RUnlock()
	return origin
}
// SetClientIP stores the resolved client IP while holding the write lock.
func (c *CapturedData) SetClientIP(ip string) {
	c.mu.Lock()
	c.ClientIP = ip
	c.mu.Unlock()
}
// GetClientIP returns the resolved client IP, reading it under the read lock.
func (c *CapturedData) GetClientIP() string {
	c.mu.RLock()
	clientIP := c.ClientIP
	c.mu.RUnlock()
	return clientIP
}
// SetUserID stores the authenticated user ID while holding the write lock.
func (c *CapturedData) SetUserID(userID string) {
	c.mu.Lock()
	c.UserID = userID
	c.mu.Unlock()
}
// GetUserID returns the authenticated user ID, reading it under the read lock.
func (c *CapturedData) GetUserID() string {
	c.mu.RLock()
	userID := c.UserID
	c.mu.RUnlock()
	return userID
}
// SetAuthMethod stores the authentication method while holding the write lock.
func (c *CapturedData) SetAuthMethod(method string) {
	c.mu.Lock()
	c.AuthMethod = method
	c.mu.Unlock()
}
// GetAuthMethod returns the authentication method, reading it under the read lock.
func (c *CapturedData) GetAuthMethod() string {
	c.mu.RLock()
	authMethod := c.AuthMethod
	c.mu.RUnlock()
	return authMethod
}
// WithCapturedData returns a copy of ctx that carries data, retrievable with
// CapturedDataFromContext.
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
	derived := context.WithValue(ctx, capturedDataKey, data)
	return derived
}
// CapturedDataFromContext returns the *CapturedData stored in ctx, or nil when
// ctx carries none.
func CapturedDataFromContext(ctx context.Context) *CapturedData {
	if data, ok := ctx.Value(capturedDataKey).(*CapturedData); ok {
		return data
	}
	return nil
}
// withServiceId returns a copy of ctx that carries serviceId.
func withServiceId(ctx context.Context, serviceId string) context.Context {
	derived := context.WithValue(ctx, serviceIdKey, serviceId)
	return derived
}
// ServiceIdFromContext returns the service ID stored in ctx, or the empty
// string when ctx carries none.
func ServiceIdFromContext(ctx context.Context) string {
	if serviceID, ok := ctx.Value(serviceIdKey).(string); ok {
		return serviceID
	}
	return ""
}
// withAccountId returns a copy of ctx that carries accountId.
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
	derived := context.WithValue(ctx, accountIdKey, accountId)
	return derived
}
// AccountIdFromContext returns the account ID stored in ctx, or the empty
// account ID when ctx carries none.
func AccountIdFromContext(ctx context.Context) types.AccountID {
	if accountID, ok := ctx.Value(accountIdKey).(types.AccountID); ok {
		return accountID
	}
	return ""
}

View File

@@ -0,0 +1,130 @@
package proxy_test
import (
"crypto/rand"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// nopTransport is an http.RoundTripper that never touches the network: every
// request is answered with an empty 200 response.
type nopTransport struct{}

// RoundTrip ignores the request and returns a fresh 200 response with no body.
func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
	res := &http.Response{}
	res.StatusCode = http.StatusOK
	res.Body = http.NoBody
	return res, nil
}
// BenchmarkServeHTTP measures ServeHTTP for a proxy with a single host that
// has a single "/" path mapping, using a transport that does no network I/O.
func BenchmarkServeHTTP(b *testing.B) {
	const host = "app.example.com"

	rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)

	mapping := proxy.Mapping{
		ID:        rand.Text(),
		AccountID: types.AccountID(rand.Text()),
		Host:      host,
		Paths: map[string]*url.URL{
			"/": {Scheme: "http", Host: "10.0.0.1:8080"},
		},
	}
	rp.AddMapping(mapping)

	// The same request is replayed against a fresh recorder on every iteration.
	req := httptest.NewRequest(http.MethodGet, "http://"+host, nil)
	req.Host = host
	req.RemoteAddr = "203.0.113.50:12345"

	for b.Loop() {
		rp.ServeHTTP(httptest.NewRecorder(), req)
	}
}
// BenchmarkServeHTTPHostCount measures ServeHTTP as the number of configured
// hosts grows. For each host count it registers that many single-path
// mappings under random host names and then repeatedly requests one of them,
// chosen at random.
func BenchmarkServeHTTPHostCount(b *testing.B) {
	hostCounts := []int{1, 10, 100, 1_000, 10_000}
	for _, hostCount := range hostCounts {
		b.Run(fmt.Sprintf("hosts=%d", hostCount), func(b *testing.B) {
			rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)

			// target is the host name the benchmark requests will be sent to.
			var target string
			targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(hostCount)))
			if err != nil {
				b.Fatal(err)
			}
			for i := range hostCount {
				id := rand.Text()
				host := fmt.Sprintf("%s.example.com", id)
				if int64(i) == targetIndex.Int64() {
					// Remember the full host name the mapping is registered
					// under. Using the bare id here made the request host
					// differ from every registered mapping Host.
					target = host
				}
				rp.AddMapping(proxy.Mapping{
					ID:        id,
					AccountID: types.AccountID(rand.Text()),
					Host:      host,
					Paths: map[string]*url.URL{
						"/": {
							Scheme: "http",
							Host:   "10.0.0.1:8080",
						},
					},
				})
			}

			req := httptest.NewRequest(http.MethodGet, "http://"+target+"/", nil)
			req.Host = target
			req.RemoteAddr = "203.0.113.50:12345"

			for b.Loop() {
				rp.ServeHTTP(httptest.NewRecorder(), req)
			}
		})
	}
}
// BenchmarkServeHTTPPathCount measures ServeHTTP as the number of paths
// configured on a single host grows. One randomly chosen path is requested
// repeatedly for each path count.
func BenchmarkServeHTTPPathCount(b *testing.B) {
	const host = "app.example.com"

	for _, pathCount := range []int{1, 5, 10, 25, 50} {
		b.Run(fmt.Sprintf("paths=%d", pathCount), func(b *testing.B) {
			rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)

			pick, err := rand.Int(rand.Reader, big.NewInt(int64(pathCount)))
			if err != nil {
				b.Fatal(err)
			}
			picked := pick.Int64()

			// requestPath is the path the benchmark requests will hit.
			var requestPath string
			routes := make(map[string]*url.URL, pathCount)
			for i := range pathCount {
				route := "/" + rand.Text()
				if int64(i) == picked {
					requestPath = route
				}
				// Each path points at a distinct backend port.
				routes[route] = &url.URL{
					Scheme: "http",
					Host:   fmt.Sprintf("10.0.0.1:%d", 8080+i),
				}
			}

			rp.AddMapping(proxy.Mapping{
				ID:        rand.Text(),
				AccountID: types.AccountID(rand.Text()),
				Host:      host,
				Paths:     routes,
			})

			req := httptest.NewRequest(http.MethodGet, "http://"+host+requestPath, nil)
			req.Host = host
			req.RemoteAddr = "203.0.113.50:12345"

			for b.Loop() {
				rp.ServeHTTP(httptest.NewRecorder(), req)
			}
		})
	}
}

View File

@@ -0,0 +1,406 @@
package proxy
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/web"
)
// ReverseProxy is an http.Handler that routes each incoming request to a
// backend target chosen from its internal mappings.
type ReverseProxy struct {
	// transport is used as the Transport of the httputil.ReverseProxy that
	// ServeHTTP builds for each request.
	transport http.RoundTripper
	// forwardedProto overrides the X-Forwarded-Proto header value.
	// Valid values: "auto" (detect from TLS), "http", "https".
	forwardedProto string
	// trustedProxies is a list of IP prefixes for trusted upstream proxies.
	// When the direct connection comes from a trusted proxy, forwarding
	// headers are preserved and appended to instead of being stripped.
	trustedProxies []netip.Prefix
	// mappingsMux guards mappings. NOTE(review): the code that takes this
	// lock is not visible in this part of the file — confirm.
	mappingsMux sync.RWMutex
	// mappings holds the configured routes; presumably keyed by host — TODO confirm.
	mappings map[string]Mapping
	// logger is the logger supplied to NewReverseProxy, or the logrus
	// standard logger when none was given.
	logger *log.Logger
}
// NewReverseProxy returns a ReverseProxy with an empty mapping table.
//
// The proxy wraps httputil.ReverseProxy and picks the backend for each request
// from its internal mappings, which are managed through AddMapping and
// RemoveMapping. When logger is nil the logrus standard logger is used.
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy {
	rp := &ReverseProxy{
		transport:      transport,
		forwardedProto: forwardedProto,
		trustedProxies: trustedProxies,
		mappings:       make(map[string]Mapping),
		logger:         logger,
	}
	if rp.logger == nil {
		rp.logger = log.StandardLogger()
	}
	return rp
}
// ServeHTTP looks up the backend target for r and proxies the request to it.
// When no mapping matches, a 404 error page is served instead.
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	result, exists := p.findTargetForRequest(r)
	if !exists {
		// Record that the response came from the proxy's routing, not a backend.
		if cd := CapturedDataFromContext(r.Context()); cd != nil {
			cd.SetOrigin(OriginNoRoute)
		}
		web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
			"The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.",
			getRequestID(r), web.ErrorStatus{Proxy: true, Destination: false})
		return
	}

	// Attach routing results to the context: the service ID and account ID
	// for later retrieval by middleware, and the account ID once more in the
	// form the round tripper reads.
	ctx := r.Context()
	ctx = withServiceId(ctx, result.serviceID)
	ctx = withAccountId(ctx, result.accountID)
	ctx = roundtrip.WithAccountID(ctx, result.accountID)

	// Context values only flow downwards. To hand the routing result back UP
	// the middleware chain, mutate the shared CapturedData pointer that outer
	// middleware placed in the context, if there is one.
	if shared := CapturedDataFromContext(ctx); shared != nil {
		shared.SetServiceId(result.serviceID)
		shared.SetAccountId(result.accountID)
	}

	// A fresh httputil.ReverseProxy is configured for every request because
	// the rewrite rules depend on the matched route.
	rp := &httputil.ReverseProxy{
		Rewrite:      p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
		Transport:    p.transport,
		ErrorHandler: proxyErrorHandler,
	}
	if result.rewriteRedirects {
		rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
	}

	rp.ServeHTTP(w, r.WithContext(ctx))
}
// rewriteFunc builds the Rewrite hook for httputil.ReverseProxy.
//
// The hook strips the matched route prefix from the path, points the request
// at target, chooses the Host header (the client's original host when
// passHostHeader is true, the backend's otherwise), sets the forwarding
// headers according to whether the direct peer is a trusted proxy, and removes
// the proxy's own session cookie and session_token query parameter.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
	// The root route and an empty match have no prefix to strip.
	stripPrefix := matchedPath != "" && matchedPath != "/"

	return func(pr *httputil.ProxyRequest) {
		// The prefix must be removed before SetURL joins the request path
		// onto the target's base path, otherwise it would be duplicated.
		if stripPrefix {
			trimmed := strings.TrimPrefix(pr.Out.URL.Path, matchedPath)
			if trimmed == "" {
				trimmed = "/"
			}
			pr.Out.URL.Path = trimmed
			pr.Out.URL.RawPath = ""
		}

		pr.SetURL(target)

		pr.Out.Host = target.Host
		if passHostHeader {
			pr.Out.Host = pr.In.Host
		}

		clientIP := extractClientIP(pr.In.RemoteAddr)
		switch {
		case IsTrustedProxy(clientIP, p.trustedProxies):
			p.setTrustedForwardingHeaders(pr, clientIP)
		default:
			p.setUntrustedForwardingHeaders(pr, clientIP)
		}

		stripSessionCookie(pr)
		stripSessionTokenQuery(pr)
	}
}
// rewriteLocationFunc builds a ModifyResponse hook that rewrites Location
// headers pointing at the backend so they use the public host and scheme of
// the incoming request, with the stripped route prefix put back.
//
// Responses without a Location header, with a relative Location, or with a
// Location on some other host are left untouched. A Location that cannot be
// parsed makes the hook return an error.
func (p *ReverseProxy) rewriteLocationFunc(target *url.URL, matchedPath string, inReq *http.Request) func(*http.Response) error {
	publicHost := inReq.Host
	publicScheme := auth.ResolveProto(p.forwardedProto, inReq.TLS)
	restorePrefix := matchedPath != "" && matchedPath != "/"

	return func(resp *http.Response) error {
		raw := resp.Header.Get("Location")
		if raw == "" {
			return nil
		}

		redirect, err := url.Parse(raw)
		if err != nil {
			return fmt.Errorf("parse Location header %q: %w", raw, err)
		}

		// Only absolute URLs whose authority is the backend's get rewritten.
		pointsAtBackend := redirect.Host != "" && hostsEqual(redirect, target)
		if !pointsAtBackend {
			return nil
		}

		redirect.Host = publicHost
		redirect.Scheme = publicScheme

		// Put the route prefix back in front. Trimming both sides around the
		// joining slash avoids a double slash.
		if restorePrefix {
			redirect.Path = strings.TrimRight(matchedPath, "/") + "/" + strings.TrimLeft(redirect.Path, "/")
		}

		resp.Header.Set("Location", redirect.String())
		return nil
	}
}
// hostsEqual reports whether a and b have the same authority once each has
// been passed through normalizeHost, so an explicit default port (443 for
// https, 80 for http) compares equal to no port at all.
func hostsEqual(a, b *url.URL) bool {
	left, right := normalizeHost(a), normalizeHost(b)
	return left == right
}
// normalizeHost returns u.Host with the port removed when it is the default
// port for u.Scheme (443 for https, 80 for http). In every other case,
// including a host without a port, u.Host is returned unchanged.
//
// Only the ":port" suffix is removed, so an IPv6 literal keeps its brackets:
// "[::1]:80" becomes "[::1]", which matches the port-less form of the same
// authority. Returning the bare address from net.SplitHostPort would drop the
// brackets and make the two forms compare unequal.
func normalizeHost(u *url.URL) string {
	_, port, err := net.SplitHostPort(u.Host)
	if err != nil {
		// No port present (or not in host:port form): nothing to strip.
		return u.Host
	}
	if (u.Scheme == "https" && port == "443") || (u.Scheme == "http" && port == "80") {
		return strings.TrimSuffix(u.Host, ":"+port)
	}
	return u.Host
}
// setTrustedForwardingHeaders sets the forwarding headers on the outgoing
// request when the direct connection comes from a trusted proxy. The upstream
// values are kept where present and the direct connection IP is appended to
// the X-Forwarded-For chain.
//
// clientIP is the IP of the direct connection, i.e. the trusted proxy itself.
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
	// X-Forwarded-For may arrive as several header lines, which together form
	// one comma-separated list. Header.Get returns only the first line, so
	// join all of them to avoid dropping hops that were added as a new line.
	priorChain := strings.Join(r.In.Header.Values("X-Forwarded-For"), ", ")

	// Append the direct connection IP to the existing X-Forwarded-For chain.
	if priorChain != "" {
		r.Out.Header.Set("X-Forwarded-For", priorChain+", "+clientIP)
	} else {
		r.Out.Header.Set("X-Forwarded-For", clientIP)
	}

	// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
	if realIP := r.In.Header.Get("X-Real-IP"); realIP != "" {
		r.Out.Header.Set("X-Real-IP", realIP)
	} else {
		resolved := ResolveClientIP(r.In.RemoteAddr, priorChain, p.trustedProxies)
		r.Out.Header.Set("X-Real-IP", resolved)
	}

	// Preserve upstream X-Forwarded-Host if present.
	if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" {
		r.Out.Header.Set("X-Forwarded-Host", fwdHost)
	} else {
		r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
	}

	// Trust upstream X-Forwarded-Proto; fall back to local resolution.
	if fwdProto := r.In.Header.Get("X-Forwarded-Proto"); fwdProto != "" {
		r.Out.Header.Set("X-Forwarded-Proto", fwdProto)
	} else {
		r.Out.Header.Set("X-Forwarded-Proto", auth.ResolveProto(p.forwardedProto, r.In.TLS))
	}

	// Trust upstream X-Forwarded-Port; fall back to local computation based on
	// the protocol that was just written to the outgoing request.
	if fwdPort := r.In.Header.Get("X-Forwarded-Port"); fwdPort != "" {
		r.Out.Header.Set("X-Forwarded-Port", fwdPort)
	} else {
		resolvedProto := r.Out.Header.Get("X-Forwarded-Proto")
		r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, resolvedProto))
	}
}
// setUntrustedForwardingHeaders overwrites every forwarding header on the
// outgoing request with values derived from the direct connection, ignoring
// whatever the client sent. It is used when the direct peer is not a trusted
// proxy.
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
	out := r.Out.Header
	originalHost := r.In.Host
	proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)

	out.Set("X-Forwarded-For", clientIP)
	out.Set("X-Real-IP", clientIP)
	out.Set("X-Forwarded-Host", originalHost)
	out.Set("X-Forwarded-Proto", proto)
	out.Set("X-Forwarded-Port", extractForwardedPort(originalHost, proto))
}
// stripSessionCookie rebuilds the outgoing Cookie header from the incoming
// request's cookies, leaving out the one named auth.SessionCookieName.
func stripSessionCookie(r *httputil.ProxyRequest) {
	incoming := r.In.Cookies()

	// Start from a clean slate, then re-add everything except the session cookie.
	r.Out.Header.Del("Cookie")
	for _, cookie := range incoming {
		if cookie.Name == auth.SessionCookieName {
			continue
		}
		r.Out.AddCookie(cookie)
	}
}
// stripSessionTokenQuery drops the session_token query parameter from the
// outgoing URL so the credential is not passed on to the backend. The query
// string is only re-encoded when the parameter is present.
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
	query := r.Out.URL.Query()
	if !query.Has("session_token") {
		return
	}
	query.Del("session_token")
	r.Out.URL.RawQuery = query.Encode()
}
// extractClientIP returns the host part of a host:port remote address. When
// remoteAddr cannot be split it is returned unchanged.
func extractClientIP(remoteAddr string) string {
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	return remoteAddr
}
// extractForwardedPort returns the explicit port of host when it has one.
// Otherwise it returns the default port for resolvedProto: "443" for "https"
// and "80" for anything else.
func extractForwardedPort(host, resolvedProto string) string {
	if _, port, err := net.SplitHostPort(host); err == nil && port != "" {
		return port
	}

	switch resolvedProto {
	case "https":
		return "443"
	default:
		return "80"
	}
}
// proxyErrorHandler is the httputil.ReverseProxy error handler. It marks the
// response as a proxy error in the captured data, logs the failure, and serves
// an error page chosen by classifyProxyError.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
	if captured := CapturedDataFromContext(r.Context()); captured != nil {
		captured.SetOrigin(OriginProxyError)
	}

	var (
		requestID = getRequestID(r)
		clientIP  = getClientIP(r)
	)
	title, message, code, status := classifyProxyError(err)

	log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
		requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)

	web.ServeErrorPage(w, r, code, title, message, requestID, status)
}
// getClientIP returns the client IP recorded in the request's captured data,
// or the empty string when the context carries no captured data.
func getClientIP(r *http.Request) string {
	captured := CapturedDataFromContext(r.Context())
	if captured == nil {
		return ""
	}
	return captured.GetClientIP()
}
// getRequestID returns the request ID recorded in the request's captured data,
// or the empty string when the context carries no captured data.
func getRequestID(r *http.Request) string {
	captured := CapturedDataFromContext(r.Context())
	if captured == nil {
		return ""
	}
	return captured.GetRequestID()
}
// classifyProxyError maps a proxying error to the title and message shown on
// the error page, the HTTP status code to respond with, and the component
// status to display.
//
// Cases are checked in order, so timeouts win over cancellation and both win
// over the roundtrip and connection-level errors. Anything unrecognised is
// reported as a generic connection error.
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
	// Most outcomes are a 502 with the proxy shown as healthy; the cases below
	// override only what differs.
	code = http.StatusBadGateway
	status = web.ErrorStatus{Proxy: true, Destination: false}

	switch {
	case errors.Is(err, context.DeadlineExceeded), isNetTimeout(err):
		title = "Request Timeout"
		message = "The request timed out while trying to reach the service. Please refresh the page and try again."
		code = http.StatusGatewayTimeout

	case errors.Is(err, context.Canceled):
		title = "Request Canceled"
		message = "The request was canceled before it could be completed. Please refresh the page and try again."

	case errors.Is(err, roundtrip.ErrNoAccountID):
		title = "Configuration Error"
		message = "The request could not be processed due to a configuration issue. Please refresh the page and try again."
		code = http.StatusInternalServerError
		status = web.ErrorStatus{Proxy: false, Destination: false}

	case errors.Is(err, roundtrip.ErrNoPeerConnection),
		errors.Is(err, roundtrip.ErrClientStartFailed):
		title = "Proxy Not Connected"
		message = "The proxy is not connected to the NetBird network. Please try again later or contact your administrator."
		status = web.ErrorStatus{Proxy: false, Destination: false}

	case errors.Is(err, roundtrip.ErrTooManyInflight):
		title = "Service Overloaded"
		message = "The service is currently handling too many requests. Please try again shortly."
		code = http.StatusServiceUnavailable

	case isConnectionRefused(err):
		title = "Service Unavailable"
		message = "The connection to the service was refused. Please verify that the service is running and try again."

	case isHostUnreachable(err):
		title = "Peer Not Connected"
		message = "The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network."

	default:
		title = "Connection Error"
		message = "An unexpected error occurred while connecting to the service. Please try again later."
	}

	return title, message, code, status
}
// isConnectionRefused reports whether err wraps a *net.OpError whose inner
// error message contains "refused". Matching on the message covers both the
// standard "connection refused" and gVisor's "connection was refused".
func isConnectionRefused(err error) bool {
	const marker = "refused"
	return opErrorContains(err, marker)
}
// isHostUnreachable reports whether err wraps a *net.OpError whose inner error
// message contains "unreachable" or "no route to host". That covers the
// standard "network is unreachable" and "no route to host" as well as gVisor's
// "host is unreachable".
func isHostUnreachable(err error) bool {
	if opErrorContains(err, "unreachable") {
		return true
	}
	return opErrorContains(err, "no route to host")
}
// isNetTimeout reports whether err, or an error it wraps, is a net.Error whose
// Timeout method returns true.
func isNetTimeout(err error) bool {
	var timeoutErr net.Error
	if !errors.As(err, &timeoutErr) {
		return false
	}
	return timeoutErr.Timeout()
}
// opErrorContains reports whether err wraps a *net.OpError whose inner error
// message contains substr. Comparing message text rather than error values
// lets it recognise gVisor netstack errors, which carry plain strings instead
// of syscall.Errno values.
func opErrorContains(err error, substr string) bool {
	var opErr *net.OpError
	if !errors.As(err, &opErr) || opErr.Err == nil {
		return false
	}
	return strings.Contains(opErr.Err.Error(), substr)
}

View File

@@ -0,0 +1,966 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"os"
"syscall"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/web"
)
// TestRewriteFunc_HostRewriting checks which Host header the rewrite hook puts
// on the outgoing request, with and without passHostHeader.
func TestRewriteFunc_HostRewriting(t *testing.T) {
	const (
		publicURL  = "https://public.example.com/path"
		remoteAddr = "203.0.113.1:12345"
	)
	backend, _ := url.Parse("http://backend.internal:8080")
	rp := &ReverseProxy{forwardedProto: "auto"}

	t.Run("rewrites host to backend by default", func(t *testing.T) {
		pr := newProxyRequest(t, publicURL, remoteAddr)
		rp.rewriteFunc(backend, "", false)(pr)

		assert.Equal(t, "backend.internal:8080", pr.Out.Host)
	})

	t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
		pr := newProxyRequest(t, publicURL, remoteAddr)
		rp.rewriteFunc(backend, "", true)(pr)

		assert.Equal(t, "public.example.com", pr.Out.Host,
			"Host header should be the original client host")
		assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host,
			"URL host (used for TLS/SNI) must still point to the backend")
	})
}
// TestRewriteFunc_XForwardedForStripping checks that, with no trusted proxies
// configured, client-supplied X-Forwarded-For and X-Real-IP values are
// replaced by the direct connection's IP.
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	rp := &ReverseProxy{forwardedProto: "auto"}
	rewrite := rp.rewriteFunc(backend, "", false)

	cases := []struct {
		name string
		// spoofHeader/spoofValue are set on the incoming request when non-empty.
		spoofHeader string
		spoofValue  string
		// checkHeader is the outgoing header that must equal the client IP.
		checkHeader string
		failMsg     string
	}{
		{
			name:        "sets X-Forwarded-For from direct connection IP",
			checkHeader: "X-Forwarded-For",
			failMsg:     "should be set to the connecting client IP",
		},
		{
			name:        "strips spoofed X-Forwarded-For from client",
			spoofHeader: "X-Forwarded-For",
			spoofValue:  "10.0.0.1, 172.16.0.1",
			checkHeader: "X-Forwarded-For",
			failMsg:     "spoofed XFF must be replaced, not appended to",
		},
		{
			name:        "strips spoofed X-Real-IP from client",
			spoofHeader: "X-Real-IP",
			spoofValue:  "10.0.0.1",
			checkHeader: "X-Real-IP",
			failMsg:     "spoofed X-Real-IP must be replaced",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
			if tc.spoofHeader != "" {
				pr.In.Header.Set(tc.spoofHeader, tc.spoofValue)
			}

			rewrite(pr)

			assert.Equal(t, "203.0.113.50", pr.Out.Header.Get(tc.checkHeader), tc.failMsg)
		})
	}
}
// TestRewriteFunc_ForwardedHostAndProto checks the X-Forwarded-Host,
// X-Forwarded-Port and X-Forwarded-Proto values the rewrite hook sets for
// direct (untrusted) connections under the different forwardedProto modes.
func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")

	cases := []struct {
		name           string
		forwardedProto string
		requestURL     string
		// overTLS marks the incoming request as having arrived over TLS.
		overTLS bool
		header  string
		want    string
	}{
		{
			name:           "sets X-Forwarded-Host to original host",
			forwardedProto: "auto",
			requestURL:     "http://myapp.example.com:8443/path",
			header:         "X-Forwarded-Host",
			want:           "myapp.example.com:8443",
		},
		{
			name:           "sets X-Forwarded-Port from explicit host port",
			forwardedProto: "auto",
			requestURL:     "http://example.com:8443/path",
			header:         "X-Forwarded-Port",
			want:           "8443",
		},
		{
			name:           "defaults X-Forwarded-Port to 443 for https",
			forwardedProto: "auto",
			requestURL:     "https://example.com/",
			overTLS:        true,
			header:         "X-Forwarded-Port",
			want:           "443",
		},
		{
			name:           "defaults X-Forwarded-Port to 80 for http",
			forwardedProto: "auto",
			requestURL:     "http://example.com/",
			header:         "X-Forwarded-Port",
			want:           "80",
		},
		{
			name:           "auto detects https from TLS",
			forwardedProto: "auto",
			requestURL:     "https://example.com/",
			overTLS:        true,
			header:         "X-Forwarded-Proto",
			want:           "https",
		},
		{
			name:           "auto detects http without TLS",
			forwardedProto: "auto",
			requestURL:     "http://example.com/",
			header:         "X-Forwarded-Proto",
			want:           "http",
		},
		{
			// No TLS on the connection, but the proxy is forced to https.
			name:           "forced proto overrides TLS detection",
			forwardedProto: "https",
			requestURL:     "http://example.com/",
			header:         "X-Forwarded-Proto",
			want:           "https",
		},
		{
			name:           "forced http proto",
			forwardedProto: "http",
			requestURL:     "https://example.com/",
			overTLS:        true,
			header:         "X-Forwarded-Proto",
			want:           "http",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rp := &ReverseProxy{forwardedProto: tc.forwardedProto}
			rewrite := rp.rewriteFunc(backend, "", false)

			pr := newProxyRequest(t, tc.requestURL, "1.2.3.4:5000")
			if tc.overTLS {
				pr.In.TLS = &tls.ConnectionState{}
			}

			rewrite(pr)

			assert.Equal(t, tc.want, pr.Out.Header.Get(tc.header))
		})
	}
}
// TestRewriteFunc_SessionCookieStripping checks that the rewrite function
// drops the proxy's own session cookie from the outgoing request while
// leaving every other cookie in place.
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	proxy := &ReverseProxy{forwardedProto: "auto"}
	rewrite := proxy.rewriteFunc(backend, "", false)

	// send attaches the given cookies to a fresh inbound request, runs the
	// rewrite and hands back the outgoing request for inspection.
	send := func(t *testing.T, cookies ...*http.Cookie) *http.Request {
		t.Helper()
		pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
		for _, c := range cookies {
			pr.In.AddCookie(c)
		}
		rewrite(pr)
		return pr.Out
	}

	t.Run("strips nb_session cookie", func(t *testing.T) {
		out := send(t, &http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token-here"})

		for _, forwarded := range out.Cookies() {
			assert.NotEqual(t, auth.SessionCookieName, forwarded.Name,
				"proxy session cookie must not be forwarded to backend")
		}
	})

	t.Run("preserves other cookies", func(t *testing.T) {
		out := send(t,
			&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token"},
			&http.Cookie{Name: "app_session", Value: "app-value"},
			&http.Cookie{Name: "tracking", Value: "track-value"},
		)

		forwarded := out.Cookies()
		names := make([]string, 0, len(forwarded))
		for i := range forwarded {
			names = append(names, forwarded[i].Name)
		}

		assert.Contains(t, names, "app_session", "non-proxy cookies should be preserved")
		assert.Contains(t, names, "tracking", "non-proxy cookies should be preserved")
		assert.NotContains(t, names, auth.SessionCookieName, "proxy cookie must be stripped")
	})

	t.Run("handles request with no cookies", func(t *testing.T) {
		out := send(t)
		assert.Empty(t, out.Header.Get("Cookie"))
	})
}
// TestRewriteFunc_SessionTokenQueryStripping checks that the session_token
// query parameter is removed from the outgoing URL and that all other query
// parameters are forwarded untouched.
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	rewrite := (&ReverseProxy{forwardedProto: "auto"}).rewriteFunc(backend, "", false)

	// outgoingQuery rewrites a request for rawURL and returns the parsed
	// query string of the outgoing URL.
	outgoingQuery := func(t *testing.T, rawURL string) url.Values {
		t.Helper()
		pr := newProxyRequest(t, rawURL, "1.2.3.4:5000")
		rewrite(pr)
		return pr.Out.URL.Query()
	}

	t.Run("strips session_token query parameter", func(t *testing.T) {
		query := outgoingQuery(t, "http://example.com/callback?session_token=secret123&other=keep")

		assert.Empty(t, query.Get("session_token"),
			"OIDC session token must be stripped from backend request")
		assert.Equal(t, "keep", query.Get("other"),
			"other query parameters must be preserved")
	})

	t.Run("preserves query when no session_token present", func(t *testing.T) {
		query := outgoingQuery(t, "http://example.com/api?foo=bar&baz=qux")

		assert.Equal(t, "bar", query.Get("foo"))
		assert.Equal(t, "qux", query.Get("baz"))
	})
}
// TestRewriteFunc_URLRewriting checks the scheme, host and path of the
// outgoing URL for targets with a base path, with and without a matched
// path prefix.
func TestRewriteFunc_URLRewriting(t *testing.T) {
	p := &ReverseProxy{forwardedProto: "auto"}

	cases := []struct {
		name        string
		target      string
		matchedPath string
		reqURL      string
		// checkAuthority enables the scheme/host assertions; the last case
		// only looks at the path.
		checkAuthority bool
		wantScheme     string
		wantHost       string
		wantPath       string
		pathMsg        string
	}{
		{
			name:           "rewrites URL to target with path prefix",
			target:         "http://backend.internal:8080/app",
			matchedPath:    "",
			reqURL:         "http://example.com/somepath",
			checkAuthority: true,
			wantScheme:     "http",
			wantHost:       "backend.internal:8080",
			wantPath:       "/app/somepath",
			pathMsg:        "SetURL should join the target base path with the request path",
		},
		{
			name:           "strips matched path prefix to avoid duplication",
			target:         "https://backend.example.org:443/app",
			matchedPath:    "/app",
			reqURL:         "http://example.com/app",
			checkAuthority: true,
			wantScheme:     "https",
			wantHost:       "backend.example.org:443",
			wantPath:       "/app/",
			pathMsg:        "matched path prefix should be stripped before joining with target path",
		},
		{
			name:        "strips matched prefix and preserves subpath",
			target:      "https://backend.example.org:443/app",
			matchedPath: "/app",
			reqURL:      "http://example.com/app/article/123",
			wantPath:    "/app/article/123",
			pathMsg:     "subpath after matched prefix should be preserved",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			backend, _ := url.Parse(tc.target)
			rewrite := p.rewriteFunc(backend, tc.matchedPath, false)

			pr := newProxyRequest(t, tc.reqURL, "1.2.3.4:5000")
			rewrite(pr)

			if tc.checkAuthority {
				assert.Equal(t, tc.wantScheme, pr.Out.URL.Scheme)
				assert.Equal(t, tc.wantHost, pr.Out.URL.Host)
			}
			assert.Equal(t, tc.wantPath, pr.Out.URL.Path, tc.pathMsg)
		})
	}
}
func TestExtractClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"IPv6 with port", "[::1]:12345", "::1"},
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
{"IPv6 without brackets fallback", "::1", "::1"},
{"empty string fallback", "", ""},
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
})
}
}
func TestExtractForwardedPort(t *testing.T) {
tests := []struct {
name string
host string
resolvedProto string
expected string
}{
{"explicit port in host", "example.com:8443", "https", "8443"},
{"explicit port overrides proto default", "example.com:9090", "http", "9090"},
{"no port defaults to 443 for https", "example.com", "https", "443"},
{"no port defaults to 80 for http", "example.com", "http", "80"},
{"IPv6 host with port", "[::1]:8080", "http", "8080"},
{"IPv6 host without port", "::1", "https", "443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractForwardedPort(tt.host, tt.resolvedProto))
})
}
}
// TestRewriteFunc_TrustedProxy checks which forwarding headers reach the
// backend depending on whether the direct peer (RemoteAddr) is inside the
// configured trusted proxy prefixes.
func TestRewriteFunc_TrustedProxy(t *testing.T) {
	backend, _ := url.Parse("http://backend.internal:8080")
	trusted := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}

	// header is a header set on the inbound request before rewriting.
	type header struct{ name, value string }
	// expectation is a header value asserted on the outgoing request.
	type expectation struct{ name, value, msg string }

	const (
		trustedPeer   = "10.0.0.1:5000"     // inside 10.0.0.0/8
		untrustedPeer = "203.0.113.50:9999" // outside 10.0.0.0/8
	)

	cases := []struct {
		name       string
		proto      string
		trusted    []netip.Prefix
		reqURL     string
		remoteAddr string
		incoming   []header
		want       []expectation
	}{
		{
			name:       "appends to X-Forwarded-For",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			incoming:   []header{{"X-Forwarded-For", "203.0.113.50"}},
			want:       []expectation{{name: "X-Forwarded-For", value: "203.0.113.50, 10.0.0.1"}},
		},
		{
			name:       "preserves upstream X-Real-IP",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			incoming: []header{
				{"X-Forwarded-For", "203.0.113.50"},
				{"X-Real-IP", "203.0.113.50"},
			},
			want: []expectation{{name: "X-Real-IP", value: "203.0.113.50"}},
		},
		{
			name:       "resolves X-Real-IP from XFF when not set by upstream",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			incoming:   []header{{"X-Forwarded-For", "203.0.113.50, 10.0.0.2"}},
			want: []expectation{
				{"X-Real-IP", "203.0.113.50", "should resolve real client through trusted chain"},
			},
		},
		{
			name:       "preserves upstream X-Forwarded-Host",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://proxy.internal/",
			remoteAddr: trustedPeer,
			incoming:   []header{{"X-Forwarded-Host", "original.example.com"}},
			want:       []expectation{{name: "X-Forwarded-Host", value: "original.example.com"}},
		},
		{
			name:       "preserves upstream X-Forwarded-Proto",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			incoming:   []header{{"X-Forwarded-Proto", "https"}},
			want:       []expectation{{name: "X-Forwarded-Proto", value: "https"}},
		},
		{
			name:       "preserves upstream X-Forwarded-Port",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			incoming:   []header{{"X-Forwarded-Port", "8443"}},
			want:       []expectation{{name: "X-Forwarded-Port", value: "8443"}},
		},
		{
			name:       "falls back to local proto when upstream does not set it",
			proto:      "https",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			want: []expectation{
				{"X-Forwarded-Proto", "https", "should use configured forwardedProto as fallback"},
			},
		},
		{
			name:       "sets X-Forwarded-Host from request when upstream does not set it",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			want:       []expectation{{name: "X-Forwarded-Host", value: "example.com"}},
		},
		{
			name:       "untrusted RemoteAddr strips headers even with trusted list",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: untrustedPeer,
			incoming: []header{
				{"X-Forwarded-For", "10.0.0.1, 172.16.0.1"},
				{"X-Real-IP", "evil"},
				{"X-Forwarded-Host", "evil.example.com"},
				{"X-Forwarded-Proto", "https"},
				{"X-Forwarded-Port", "9999"},
			},
			want: []expectation{
				{"X-Forwarded-For", "203.0.113.50", "untrusted: XFF must be replaced"},
				{"X-Real-IP", "203.0.113.50", "untrusted: X-Real-IP must be replaced"},
				{"X-Forwarded-Host", "example.com", "untrusted: host must be from direct connection"},
				{"X-Forwarded-Proto", "http", "untrusted: proto must be locally resolved"},
				{"X-Forwarded-Port", "80", "untrusted: port must be locally computed"},
			},
		},
		{
			name:       "empty trusted list behaves as untrusted",
			proto:      "auto",
			trusted:    nil,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			incoming:   []header{{"X-Forwarded-For", "203.0.113.50"}},
			want: []expectation{
				{"X-Forwarded-For", "10.0.0.1", "nil trusted list: should strip and use RemoteAddr"},
			},
		},
		{
			name:       "XFF starts fresh when trusted proxy has no upstream XFF",
			proto:      "auto",
			trusted:    trusted,
			reqURL:     "http://example.com/",
			remoteAddr: trustedPeer,
			want: []expectation{
				{"X-Forwarded-For", "10.0.0.1", "no upstream XFF: should set direct connection IP"},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			p := &ReverseProxy{forwardedProto: tc.proto, trustedProxies: tc.trusted}
			rewrite := p.rewriteFunc(backend, "", false)

			pr := newProxyRequest(t, tc.reqURL, tc.remoteAddr)
			for _, h := range tc.incoming {
				pr.In.Header.Set(h.name, h.value)
			}
			rewrite(pr)

			for _, exp := range tc.want {
				got := pr.Out.Header.Get(exp.name)
				if exp.msg == "" {
					assert.Equal(t, exp.value, got)
					continue
				}
				assert.Equal(t, exp.value, got, exp.msg)
			}
		})
	}
}
// TestRewriteFunc_PathForwarding pins the path the backend receives for
// different combinations of target URL and matched prefix. It mirrors the
// full pipeline: management builds a target URL (with the matching prefix
// baked into its path), the proxy strips the matched prefix from the request
// path, and SetURL joins the remainder onto the target path again.
func TestRewriteFunc_PathForwarding(t *testing.T) {
	p := &ReverseProxy{forwardedProto: "auto"}

	cases := []struct {
		name        string
		target      string
		matchedPath string
		reqURL      string
		wantPath    string
		msg         string
	}{
		// What ToProtoMapping produces: the target URL carries the matching
		// prefix as its path, so the proxy strips and then re-adds it.
		// Management builds: path="/heise", target="https://heise.de:443/heise".
		{
			name:        "path prefix baked into target URL is a no-op",
			target:      "https://heise.de:443/heise",
			matchedPath: "/heise",
			reqURL:      "http://external.test/heise",
			wantPath:    "/heise/",
			msg:         "backend sees /heise/ because prefix is stripped then re-added by SetURL",
		},
		{
			name:        "subpath under prefix also preserved",
			target:      "https://heise.de:443/heise",
			matchedPath: "/heise",
			reqURL:      "http://external.test/heise/article/123",
			wantPath:    "/heise/article/123",
			msg:         "subpath is preserved on top of the re-added prefix",
		},
		// What the behavior WOULD be if the target URL had no path (true stripping).
		{
			name:        "target without path prefix gives true stripping",
			target:      "https://heise.de:443",
			matchedPath: "/heise",
			reqURL:      "http://external.test/heise",
			wantPath:    "/",
			msg:         "without path in target URL, backend sees / (true prefix stripping)",
		},
		{
			name:        "target without path prefix strips and preserves subpath",
			target:      "https://heise.de:443",
			matchedPath: "/heise",
			reqURL:      "http://external.test/heise/article/123",
			wantPath:    "/article/123",
			msg:         "without path in target URL, prefix is truly stripped",
		},
		// Root path "/": nothing is expected to be stripped.
		{
			name:        "root path forwards full request path unchanged",
			target:      "https://backend.example.com:443/",
			matchedPath: "/",
			reqURL:      "http://external.test/heise",
			wantPath:    "/heise",
			msg:         "root path match must not strip anything",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			backend, _ := url.Parse(tc.target)
			rewrite := p.rewriteFunc(backend, tc.matchedPath, false)

			pr := newProxyRequest(t, tc.reqURL, "1.2.3.4:5000")
			rewrite(pr)

			assert.Equal(t, tc.wantPath, pr.Out.URL.Path, tc.msg)
		})
	}
}
func TestRewriteLocationFunc(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
newReq := func(rawURL string) *http.Request {
t.Helper()
r := httptest.NewRequest(http.MethodGet, rawURL, nil)
parsed, _ := url.Parse(rawURL)
r.Host = parsed.Host
return r
}
run := func(p *ReverseProxy, matchedPath string, inReq *http.Request, location string) (*http.Response, error) {
t.Helper()
modifyResp := p.rewriteLocationFunc(target, matchedPath, inReq) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
if location != "" {
resp.Header.Set("Location", location)
}
err := modifyResp(resp)
return resp, err
}
t.Run("rewrites Location pointing to backend", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
"http://backend.internal:8080/login")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
})
t.Run("does not rewrite Location pointing to other host", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"https://other.example.com/path")
require.NoError(t, err)
assert.Equal(t, "https://other.example.com/path", resp.Header.Get("Location"))
})
t.Run("does not rewrite relative Location", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"/dashboard")
require.NoError(t, err)
assert.Equal(t, "/dashboard", resp.Header.Get("Location"))
})
t.Run("re-adds stripped path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/users"), //nolint:bodyclose
"http://backend.internal:8080/users")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
})
t.Run("uses resolved proto for scheme", func(t *testing.T) {
resp, err := run(newProxy("auto"), "", newReq("http://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/path")
require.NoError(t, err)
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"))
})
t.Run("no-op when Location header is empty", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), "") //nolint:bodyclose
require.NoError(t, err)
assert.Empty(t, resp.Header.Get("Location"))
})
t.Run("does not prepend root path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/", newReq("https://public.example.com/login"), //nolint:bodyclose
"http://backend.internal:8080/login")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
})
// --- Edge cases: query parameters and fragments ---
t.Run("preserves query parameters", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/login?redirect=%2Fdashboard&lang=en")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login?redirect=%2Fdashboard&lang=en", resp.Header.Get("Location"))
})
t.Run("preserves fragment", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/docs#section-2")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/docs#section-2", resp.Header.Get("Location"))
})
t.Run("preserves query parameters and fragment together", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/search?q=test&page=1#results")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/search?q=test&page=1#results", resp.Header.Get("Location"))
})
t.Run("preserves query parameters with path prefix re-added", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/search"), //nolint:bodyclose
"http://backend.internal:8080/search?q=hello")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/search?q=hello", resp.Header.Get("Location"))
})
// --- Edge cases: slash handling ---
t.Run("no double slash when matchedPath has trailing slash", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api/", newReq("https://public.example.com/api/users"), //nolint:bodyclose
"http://backend.internal:8080/users")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
})
t.Run("backend redirect to root with path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/app", newReq("https://public.example.com/app/"), //nolint:bodyclose
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
})
t.Run("backend redirect to root with trailing-slash path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/app/", newReq("https://public.example.com/app/"), //nolint:bodyclose
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
})
t.Run("preserves trailing slash on redirect path", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/path/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path/", resp.Header.Get("Location"))
})
t.Run("backend redirect to bare root", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/", resp.Header.Get("Location"))
})
// --- Edge cases: host/port matching ---
t.Run("does not rewrite when backend host matches but port differs", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:9090/other")
require.NoError(t, err)
assert.Equal(t, "http://backend.internal:9090/other", resp.Header.Get("Location"),
"Different port means different host authority, must not rewrite")
})
t.Run("rewrites when redirect omits default port matching target", func(t *testing.T) {
// Target is backend.internal:8080, redirect is to backend.internal (no port).
// These are different authorities, so should NOT rewrite.
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal/path")
require.NoError(t, err)
assert.Equal(t, "http://backend.internal/path", resp.Header.Get("Location"),
"backend.internal != backend.internal:8080, must not rewrite")
})
t.Run("rewrites when target has :443 but redirect omits it for https", func(t *testing.T) {
// Target: heise.de:443, redirect: https://heise.de/path (no :443 because it's default)
// Per RFC 3986, these are the same authority.
target443, _ := url.Parse("https://heise.de:443")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(target443, "", newReq("https://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://heise.de/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
"heise.de:443 and heise.de are the same for https")
})
t.Run("rewrites when target has :80 but redirect omits it for http", func(t *testing.T) {
target80, _ := url.Parse("http://backend.local:80")
p := newProxy("http")
modifyResp := p.rewriteLocationFunc(target80, "", newReq("http://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "http://backend.local/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"),
"backend.local:80 and backend.local are the same for http")
})
t.Run("rewrites when redirect has :443 but target omits it", func(t *testing.T) {
targetNoPort, _ := url.Parse("https://heise.de")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(targetNoPort, "", newReq("https://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://heise.de:443/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
"heise.de and heise.de:443 are the same for https")
})
t.Run("does not conflate non-default ports", func(t *testing.T) {
target8443, _ := url.Parse("https://backend.internal:8443")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(target8443, "", newReq("https://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://backend.internal/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://backend.internal/path", resp.Header.Get("Location"),
"backend.internal:8443 != backend.internal (port 443), must not rewrite")
})
// --- Edge cases: encoded paths ---
t.Run("preserves percent-encoded path segments", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/path%20with%20spaces/file%2Fname")
require.NoError(t, err)
loc := resp.Header.Get("Location")
assert.Contains(t, loc, "public.example.com")
parsed, err := url.Parse(loc)
require.NoError(t, err)
assert.Equal(t, "/path with spaces/file/name", parsed.Path)
})
t.Run("preserves encoded query parameters with path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/v1", newReq("https://public.example.com/v1/"), //nolint:bodyclose
"http://backend.internal:8080/redirect?url=http%3A%2F%2Fexample.com")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/v1/redirect?url=http%3A%2F%2Fexample.com", resp.Header.Get("Location"))
})
}
// newProxyRequest builds an httputil.ProxyRequest for exercising the Rewrite
// function in tests. In is a GET request for rawURL whose Host is taken from
// the URL and whose RemoteAddr is remoteAddr; Out is a clone of In with its
// own copy of the headers, mimicking what httputil.ReverseProxy prepares
// before calling Rewrite.
func newProxyRequest(t *testing.T, rawURL, remoteAddr string) *httputil.ProxyRequest {
	t.Helper()

	u, err := url.Parse(rawURL)
	require.NoError(t, err)

	inbound := httptest.NewRequest(http.MethodGet, rawURL, nil)
	inbound.RemoteAddr, inbound.Host = remoteAddr, u.Host

	outbound := inbound.Clone(inbound.Context())
	outbound.Header = inbound.Header.Clone()

	return &httputil.ProxyRequest{
		In:  inbound,
		Out: outbound,
	}
}
// TestClassifyProxyError checks the title, HTTP status code and component
// status that classifyProxyError reports for each kind of upstream error.
func TestClassifyProxyError(t *testing.T) {
	// The two component-status values asserted below.
	proxySet := web.ErrorStatus{Proxy: true, Destination: false}
	proxyUnset := web.ErrorStatus{Proxy: false, Destination: false}

	// dialErrno wraps an errno the way a failed TCP dial surfaces it:
	// net.OpError -> os.SyscallError -> syscall.Errno.
	dialErrno := func(errno syscall.Errno) error {
		return &net.OpError{
			Op:  "dial",
			Net: "tcp",
			Err: &os.SyscallError{Syscall: "connect", Err: errno},
		}
	}
	// connectText wraps a plain-text error in a net.OpError with Op "connect";
	// the cases using it are the ones labelled "gvisor".
	connectText := func(text string) error {
		return &net.OpError{Op: "connect", Net: "tcp", Err: errors.New(text)}
	}

	cases := []struct {
		desc   string
		cause  error
		title  string
		code   int
		status web.ErrorStatus
	}{
		{"context deadline exceeded", context.DeadlineExceeded,
			"Request Timeout", http.StatusGatewayTimeout, proxySet},
		{"wrapped deadline exceeded", fmt.Errorf("dial: %w", context.DeadlineExceeded),
			"Request Timeout", http.StatusGatewayTimeout, proxySet},
		{"context canceled", context.Canceled,
			"Request Canceled", http.StatusBadGateway, proxySet},
		{"no account ID", roundtrip.ErrNoAccountID,
			"Configuration Error", http.StatusInternalServerError, proxyUnset},
		{"no peer connection", fmt.Errorf("%w for account: abc", roundtrip.ErrNoPeerConnection),
			"Proxy Not Connected", http.StatusBadGateway, proxyUnset},
		{"client not started", fmt.Errorf("%w: %w", roundtrip.ErrClientStartFailed, errors.New("engine init failed")),
			"Proxy Not Connected", http.StatusBadGateway, proxyUnset},
		{"syscall ECONNREFUSED via os.SyscallError", dialErrno(syscall.ECONNREFUSED),
			"Service Unavailable", http.StatusBadGateway, proxySet},
		{"gvisor connection was refused", connectText("connection was refused"),
			"Service Unavailable", http.StatusBadGateway, proxySet},
		{"syscall EHOSTUNREACH via os.SyscallError", dialErrno(syscall.EHOSTUNREACH),
			"Peer Not Connected", http.StatusBadGateway, proxySet},
		{"syscall ENETUNREACH via os.SyscallError", dialErrno(syscall.ENETUNREACH),
			"Peer Not Connected", http.StatusBadGateway, proxySet},
		{"gvisor host is unreachable", connectText("host is unreachable"),
			"Peer Not Connected", http.StatusBadGateway, proxySet},
		{"gvisor network is unreachable", connectText("network is unreachable"),
			"Peer Not Connected", http.StatusBadGateway, proxySet},
		// NOTE(review): same error value as the EHOSTUNREACH case above; kept
		// so the set of subtests stays unchanged.
		{"standard no route to host", dialErrno(syscall.EHOSTUNREACH),
			"Peer Not Connected", http.StatusBadGateway, proxySet},
		{"unknown error falls to default", errors.New("something unexpected"),
			"Connection Error", http.StatusBadGateway, proxySet},
	}

	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			gotTitle, _, gotCode, gotStatus := classifyProxyError(tc.cause)

			assert.Equal(t, tc.title, gotTitle, "title")
			assert.Equal(t, tc.code, gotCode, "status code")
			assert.Equal(t, tc.status, gotStatus, "component status")
		})
	}
}

View File

@@ -0,0 +1,84 @@
package proxy
import (
"net"
"net/http"
"net/url"
"sort"
"strings"
"github.com/netbirdio/netbird/proxy/internal/types"
)
// Mapping describes the routing configuration for a single host: the backend
// targets reachable under it, keyed by request path prefix, plus the
// attributes handed to the caller when a request matches.
type Mapping struct {
// ID identifies the mapping; findTargetForRequest reports it as the service ID.
ID string
AccountID types.AccountID
// Host is the key under which AddMapping stores and RemoveMapping deletes
// the mapping. Lookups strip any ":port" from the request host first.
Host string
// Paths maps a request path prefix to the backend URL registered for it.
Paths map[string]*url.URL
// PassHostHeader and RewriteRedirects are copied unchanged into the match
// result. NOTE(review): how they are applied is not visible in this file.
PassHostHeader bool
RewriteRedirects bool
}
// targetResult is what findTargetForRequest returns for a matched request:
// the selected backend URL together with the attributes of the Mapping it
// came from.
type targetResult struct {
// url is the backend URL registered for the matched path prefix.
url *url.URL
// matchedPath is the key of Mapping.Paths that prefixed the request path.
matchedPath string
// serviceID is the ID of the matching Mapping.
serviceID string
accountID types.AccountID
// passHostHeader and rewriteRedirects mirror the Mapping fields of the same name.
passHostHeader bool
rewriteRedirects bool
}
// findTargetForRequest looks up the backend target for req. The request host
// (with any port removed) selects the Mapping, and the longest path prefix of
// that mapping which prefixes req.URL.Path selects the target URL. The second
// return value is false when no mapping exists for the host or none of its
// path prefixes match.
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
	p.mappingsMux.RLock()
	defer p.mappingsMux.RUnlock()

	// Mappings are keyed by bare host, so drop a ":port" suffix when present
	// (e.g. "external.test:8443" -> "external.test").
	hostname, _, splitErr := net.SplitHostPort(req.Host)
	if splitErr != nil {
		hostname = req.Host
	}

	mapping, known := p.mappings[hostname]
	if !known {
		p.logger.Debugf("no mapping found for host: %s", hostname)
		return targetResult{}, false
	}

	// Collect every registered prefix that matches the request path.
	candidates := make([]string, 0, len(mapping.Paths))
	for prefix := range mapping.Paths {
		if strings.HasPrefix(req.URL.Path, prefix) {
			candidates = append(candidates, prefix)
		}
	}
	if len(candidates) == 0 {
		p.logger.Debugf("no path match for host: %s, path: %s", hostname, req.URL.Path)
		return targetResult{}, false
	}

	// Longest prefix first, so the most specific route wins. Two distinct
	// prefixes of the same path always differ in length, so the winner is
	// unambiguous.
	sort.Slice(candidates, func(a, b int) bool {
		return len(candidates[a]) > len(candidates[b])
	})
	best := candidates[0]
	backend := mapping.Paths[best]

	p.logger.Debugf("matched host: %s, path: %s -> %s", hostname, best, backend)
	return targetResult{
		url:              backend,
		matchedPath:      best,
		serviceID:        mapping.ID,
		accountID:        mapping.AccountID,
		passHostHeader:   mapping.PassHostHeader,
		rewriteRedirects: mapping.RewriteRedirects,
	}, true
}
// AddMapping registers m under its Host, replacing any mapping previously
// stored for that host. The write is performed under the mappings lock.
func (p *ReverseProxy) AddMapping(m Mapping) {
	host := m.Host

	p.mappingsMux.Lock()
	defer p.mappingsMux.Unlock()

	p.mappings[host] = m
}
// RemoveMapping deletes the mapping stored for m.Host, if any. Only the Host
// field of m is consulted. The delete is performed under the mappings lock.
func (p *ReverseProxy) RemoveMapping(m Mapping) {
	host := m.Host

	p.mappingsMux.Lock()
	defer p.mappingsMux.Unlock()

	delete(p.mappings, host)
}

View File

@@ -0,0 +1,60 @@
package proxy
import (
"net/netip"
"strings"
)
// IsTrustedProxy reports whether ipStr parses as an IP address that falls
// within any of the trusted prefixes.
//
// It returns false when the trusted list is empty or ipStr is not a valid IP
// address. IPv4-mapped IPv6 addresses (e.g. "::ffff:10.0.0.1") are also
// checked in their plain IPv4 form, because netip.Prefix.Contains never
// matches a mapped address against an IPv4 prefix.
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
	if len(trusted) == 0 {
		return false
	}

	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		return false
	}
	// Unmap returns addr unchanged unless it is an IPv4-mapped IPv6 address.
	unmapped := addr.Unmap()

	for _, prefix := range trusted {
		if prefix.Contains(addr) || prefix.Contains(unmapped) {
			return true
		}
	}
	return false
}
// ResolveClientIP determines the client IP for a request from its RemoteAddr
// and X-Forwarded-For value, honoring the header only when the direct peer is
// a trusted proxy.
//
// When the trusted list is empty, the peer is not trusted, or xff is empty,
// the IP extracted from remoteAddr is returned and xff is ignored. Otherwise
// the XFF entries are scanned right-to-left, skipping blank entries, and the
// first entry that is not a trusted proxy is returned. If every entry is
// trusted, the leftmost entry is returned as a best guess, falling back to
// the peer IP when that entry is blank.
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
	peer := extractClientIP(remoteAddr)

	switch {
	case len(trusted) == 0, !IsTrustedProxy(peer, trusted), xff == "":
		return peer
	}

	hops := strings.Split(xff, ",")
	for i := len(hops) - 1; i >= 0; i-- {
		hop := strings.TrimSpace(hops[i])
		if hop != "" && !IsTrustedProxy(hop, trusted) {
			return hop
		}
	}

	// Every entry is a trusted proxy: fall back to the leftmost one.
	leftmost := strings.TrimSpace(hops[0])
	if leftmost == "" {
		return peer
	}
	return leftmost
}

View File

@@ -0,0 +1,129 @@
package proxy
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsTrustedProxy(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("fd00::/8"),
}
tests := []struct {
name string
ip string
trusted []netip.Prefix
want bool
}{
{"empty trusted list", "10.0.0.1", nil, false},
{"IP within /8 prefix", "10.1.2.3", trusted, true},
{"IP within /24 prefix", "192.168.1.100", trusted, true},
{"IP outside all prefixes", "203.0.113.50", trusted, false},
{"boundary IP just outside prefix", "192.168.2.1", trusted, false},
{"unparsable IP", "not-an-ip", trusted, false},
{"IPv6 in trusted range", "fd00::1", trusted, true},
{"IPv6 outside range", "2001:db8::1", trusted, false},
{"empty string", "", trusted, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, IsTrustedProxy(tt.ip, tt.trusted))
})
}
}
func TestResolveClientIP(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
}
tests := []struct {
name string
remoteAddr string
xff string
trusted []netip.Prefix
want string
}{
{
name: "empty trusted list returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4",
trusted: nil,
want: "203.0.113.50",
},
{
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4, 10.0.0.1",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
xff: "",
trusted: trusted,
want: "10.0.0.1",
},
{
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trusted: trusted,
want: "10.0.0.2",
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xff: " 203.0.113.50 , 10.0.0.2 ",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "XFF with empty segments",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50,,10.0.0.2",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "RemoteAddr without port",
remoteAddr: "10.0.0.1",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, ResolveClientIP(tt.remoteAddr, tt.xff, tt.trusted))
})
}
}

View File

@@ -0,0 +1,575 @@
package roundtrip
import (
"context"
"errors"
"fmt"
"net/http"
"sync"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/embed"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
// deviceNamePrefix is prepended to the proxy ID to form the device name of
// every embedded client created by this package.
const deviceNamePrefix = "ingress-proxy-"
// backendKey identifies a backend by its host:port from the target URL.
// It is an alias for string, so plain strings can be used wherever a
// backendKey is expected.
type backendKey = string
// Sentinel errors produced by this package. Some of them are returned wrapped
// with extra context, so callers should match them with errors.Is.
var (
// ErrNoAccountID is returned when a request context is missing the account ID.
ErrNoAccountID = errors.New("no account ID in request context")
// ErrNoPeerConnection is returned when no embedded client exists for the account.
ErrNoPeerConnection = errors.New("no peer connection found")
// ErrClientStartFailed is returned when the embedded client fails to start.
ErrClientStartFailed = errors.New("client start failed")
// ErrTooManyInflight is returned when the per-backend in-flight limit is reached.
ErrTooManyInflight = errors.New("too many in-flight requests")
)
// domainInfo holds metadata about a registered domain.
type domainInfo struct {
// serviceID is the service the domain was registered for; it is passed back
// to the status notifier when the domain's connection state is reported.
serviceID string
}
// domainNotification pairs a domain with its service ID. It is used to
// snapshot the set of domains to notify so the notifications can be sent
// after the clients lock has been released.
type domainNotification struct {
domain domain.Domain
serviceID string
}
// clientEntry holds an embedded NetBird client and tracks which domains use it.
// The domains map and the started flag are accessed under NetBird.clientsMux;
// the in-flight bookkeeping has its own mutex.
type clientEntry struct {
// client is the embedded NetBird client shared by all domains of the account.
client *embed.Client
// transport is the HTTP transport that dials backends through client.
transport *http.Transport
// domains maps each registered domain to its metadata.
domains map[domain.Domain]domainInfo
// createdAt records when the entry was created.
createdAt time.Time
// started is set once the background client startup has succeeded.
started bool
// Per-backend in-flight limiting keyed by target host:port.
// Each channel acts as a counting semaphore with capacity maxInflight;
// a maxInflight of zero or less disables limiting.
// TODO: clean up stale entries when backend targets change.
inflightMu sync.Mutex
inflightMap map[backendKey]chan struct{}
maxInflight int
}
// acquireInflight tries to reserve one in-flight slot for backend without
// blocking.
//
// ok is false when the backend already has maxInflight requests outstanding.
// release is never nil and may be called in either case; after a successful
// acquisition it frees the reserved slot, otherwise it does nothing. A
// non-positive maxInflight disables limiting entirely.
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
	if e.maxInflight <= 0 {
		return func() {}, true
	}

	// Look up (or lazily create) the semaphore for this backend under the lock.
	slots := func() chan struct{} {
		e.inflightMu.Lock()
		defer e.inflightMu.Unlock()

		ch, found := e.inflightMap[backend]
		if !found {
			ch = make(chan struct{}, e.maxInflight)
			e.inflightMap[backend] = ch
		}
		return ch
	}()

	// Non-blocking send: a full channel means the limit has been reached.
	select {
	case slots <- struct{}{}:
	default:
		return func() {}, false
	}
	return func() { <-slots }, true
}
// statusNotifier is the hook used to report whether the tunnel backing a
// domain is connected. Implementations are optional: a nil notifier is
// tolerated by the NetBird transport.
type statusNotifier interface {
// NotifyStatus reports the connection state of domain for the given
// account and service.
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
}
// managementClient is the subset of the management gRPC API needed to register
// a new proxy peer before its embedded client is created.
type managementClient interface {
// CreateProxyPeer registers the proxy peer described by req with management.
CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest, opts ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error)
}
// NetBird provides an http.RoundTripper implementation
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
type NetBird struct {
// mgmtAddr is the management URL handed to every embedded client.
mgmtAddr string
// proxyID is appended to deviceNamePrefix to name embedded clients.
proxyID string
// proxyAddr is sent to management as the cluster when creating a proxy peer.
proxyAddr string
// wgPort is the WireGuard port passed to embedded clients.
wgPort int
logger *log.Logger
mgmtClient managementClient
// transportCfg supplies the HTTP transport settings and in-flight limit
// applied to each new client entry.
transportCfg transportConfig
// clientsMux guards clients and the mutable fields of its entries.
clientsMux sync.RWMutex
clients map[types.AccountID]*clientEntry
// initLogOnce ensures embedded-client logging is initialised only once.
initLogOnce sync.Once
// statusNotifier may be nil, in which case no notifications are sent.
statusNotifier statusNotifier
}
// ClientDebugInfo contains debug information about a client.
type ClientDebugInfo struct {
// AccountID is the account the client belongs to.
AccountID types.AccountID
// DomainCount is the number of domains registered for the account.
DomainCount int
// Domains lists the registered domains in no particular order.
Domains domain.List
// HasClient reports whether the entry holds a non-nil embedded client.
HasClient bool
// CreatedAt is the time the client entry was created.
CreatedAt time.Time
}
// accountIDContextKey is the context key for storing the account ID.
// Being an unexported type, it cannot collide with keys set by other packages.
type accountIDContextKey struct{}
// AddPeer registers domain d for accountID.
//
// When the account has no client yet, a proxy peer is authenticated with the
// management server using authToken, an embedded client is created, and its
// startup is kicked off in the background. When a client already exists the
// domain is simply attached to it, and, if that client has already started,
// the domain is reported as connected straight away. Several domains can
// therefore share one client.
//
// An error is returned only when creating the new client entry fails.
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
	logFields := log.Fields{
		"account_id": accountID,
		"domain":     d,
	}

	n.clientsMux.Lock()
	existing, found := n.clients[accountID]
	if !found {
		created, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
		if err != nil {
			n.clientsMux.Unlock()
			return err
		}
		n.clients[accountID] = created
		n.clientsMux.Unlock()

		n.logger.WithFields(logFields).Info("created new client for account")

		// Start in the background; should this fail, RoundTrip retries on
		// the first request.
		go n.runClientStartup(ctx, accountID, created.client)
		return nil
	}

	// The account already has a client: just attach the domain to it.
	existing.domains[d] = domainInfo{serviceID: serviceID}
	alreadyStarted := existing.started
	n.clientsMux.Unlock()

	n.logger.WithFields(logFields).Debug("registered domain with existing client")

	// Only a running client warrants an immediate "connected" notification.
	if !alreadyStarted || n.statusNotifier == nil {
		return nil
	}
	if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
		n.logger.WithFields(logFields).WithError(err).Warn("failed to notify status for existing client")
	}
	return nil
}
// createClientEntry builds everything a new account needs: it generates a
// WireGuard keypair, registers the public key with management as a proxy peer,
// creates an embedded NetBird client with the private key, and wraps that
// client's dialer in an HTTP transport configured from transportCfg.
//
// The returned entry has d registered as its only domain and is not started.
// Must be called with clientsMux held.
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
	n.logger.WithFields(log.Fields{
		"account_id": accountID,
		"service_id": serviceID,
	}).Debug("generating WireGuard keypair for new peer")

	privateKey, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return nil, fmt.Errorf("generate wireguard private key: %w", err)
	}
	pubKey := privateKey.PublicKey().String()

	peerFields := log.Fields{
		"account_id": accountID,
		"service_id": serviceID,
		"public_key": pubKey,
	}
	n.logger.WithFields(peerFields).Debug("authenticating new proxy peer with management")

	createReq := &proto.CreateProxyPeerRequest{
		ServiceId:          serviceID,
		AccountId:          string(accountID),
		Token:              authToken,
		WireguardPublicKey: pubKey,
		Cluster:            n.proxyAddr,
	}
	resp, err := n.mgmtClient.CreateProxyPeer(ctx, createReq)
	if err != nil {
		return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
	}
	if resp != nil && !resp.GetSuccess() {
		reason := "unknown error"
		if msg := resp.ErrorMessage; msg != nil {
			reason = *msg
		}
		return nil, fmt.Errorf("proxy peer authentication failed: %s", reason)
	}

	n.logger.WithFields(peerFields).Info("proxy peer authenticated successfully with management")

	// Embedded-client logging is process-wide, so set it up a single time.
	n.initLogOnce.Do(func() {
		initErr := util.InitLog(log.WarnLevel.String(), util.LogConsole)
		if initErr != nil {
			n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", initErr)
		}
	})

	// The peer already exists on the management side (registered above with
	// the public key); the embedded client is given the matching private key.
	opts := embed.Options{
		DeviceName:    deviceNamePrefix + n.proxyID,
		ManagementURL: n.mgmtAddr,
		PrivateKey:    privateKey.String(),
		LogLevel:      log.WarnLevel.String(),
		BlockInbound:  true,
		WireguardPort: &n.wgPort,
	}
	client, err := embed.New(opts)
	if err != nil {
		return nil, fmt.Errorf("create netbird client: %w", err)
	}

	// Build our own transport on top of the client dialer rather than using
	// the client's HTTPClient, whose request validation does not work with
	// reverse proxied requests.
	cfg := &n.transportCfg
	transport := &http.Transport{
		DialContext:           client.DialContext,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          cfg.maxIdleConns,
		MaxIdleConnsPerHost:   cfg.maxIdleConnsPerHost,
		MaxConnsPerHost:       cfg.maxConnsPerHost,
		IdleConnTimeout:       cfg.idleConnTimeout,
		TLSHandshakeTimeout:   cfg.tlsHandshakeTimeout,
		ExpectContinueTimeout: cfg.expectContinueTimeout,
		ResponseHeaderTimeout: cfg.responseHeaderTimeout,
		WriteBufferSize:       cfg.writeBufferSize,
		ReadBufferSize:        cfg.readBufferSize,
		DisableCompression:    cfg.disableCompression,
	}

	entry := &clientEntry{
		client:      client,
		domains:     map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
		transport:   transport,
		createdAt:   time.Now(),
		started:     false,
		inflightMap: make(map[backendKey]chan struct{}),
		maxInflight: cfg.maxInflight,
	}
	return entry, nil
}
// runClientStartup starts the client and notifies registered domains on success.
//
// The start attempt is bounded to 30 seconds and deliberately uses a fresh
// background context rather than ctx; ctx is only used for the status
// notifications. If the client turns out to be running already (for example
// because a concurrent RoundTrip started it first) that is treated as success,
// so the entry is still marked as started and its domains are still notified.
// Any other start failure is logged and the function returns; RoundTrip
// retries the start on the next request.
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
	startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// ErrClientAlreadyStarted means someone else won the race to start the
	// client; it is running, which is all this function needs.
	if err := client.Start(startCtx); err != nil && !errors.Is(err, embed.ErrClientAlreadyStarted) {
		if errors.Is(err, context.DeadlineExceeded) {
			n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
		} else {
			n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
		}
		return
	}

	// Mark client as started and collect domains to notify outside the lock.
	var domainsToNotify []domainNotification
	n.clientsMux.Lock()
	if entry, exists := n.clients[accountID]; exists {
		entry.started = true
		domainsToNotify = make([]domainNotification, 0, len(entry.domains))
		for dom, info := range entry.domains {
			domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
		}
	}
	n.clientsMux.Unlock()

	if n.statusNotifier == nil {
		return
	}
	for _, dn := range domainsToNotify {
		fields := log.Fields{
			"account_id": accountID,
			"domain":     dn.domain,
		}
		if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
			n.logger.WithFields(fields).WithError(err).Warn("failed to notify tunnel connection status")
			continue
		}
		n.logger.WithFields(fields).Info("notified management about tunnel connection")
	}
}
// RemovePeer unregisters a domain from an account. The client is only stopped
// when no domains are using it anymore.
//
// Removing a domain from an unknown account, or a domain that was never
// registered, is a no-op. Otherwise the domain is reported as disconnected
// through the status notifier (if one is configured) and, when it was the last
// domain of the account, the account's transport is drained of idle
// connections and its client is stopped. Notification and stop failures are
// logged rather than returned, so the result is always nil.
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
	n.clientsMux.Lock()
	entry, exists := n.clients[accountID]
	if !exists {
		n.clientsMux.Unlock()
		n.logger.WithField("account_id", accountID).Debug("remove peer: account not found")
		return nil
	}

	// Get domain info before deleting
	domInfo, domainExists := entry.domains[d]
	if !domainExists {
		n.clientsMux.Unlock()
		n.logger.WithFields(log.Fields{
			"account_id": accountID,
			"domain":     d,
		}).Debug("remove peer: domain not registered")
		return nil
	}
	delete(entry.domains, d)

	// Snapshot everything needed later while the lock is still held: the
	// domains map may be modified by other callers once it is released.
	remaining := len(entry.domains)
	client := entry.client
	transport := entry.transport
	if remaining == 0 {
		delete(n.clients, accountID)
	}
	n.clientsMux.Unlock()

	if remaining > 0 {
		n.logger.WithFields(log.Fields{
			"account_id":        accountID,
			"domain":            d,
			"remaining_domains": remaining,
		}).Debug("unregistered domain, client still in use")
	} else {
		n.logger.WithFields(log.Fields{
			"account_id": accountID,
		}).Info("stopping client, no more domains")
	}

	// Notify disconnection (before stopping the client, if it is to be stopped).
	if n.statusNotifier != nil {
		if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
			n.logger.WithFields(log.Fields{
				"account_id": accountID,
				"domain":     d,
			}).WithError(err).Warn("failed to notify tunnel disconnection status")
		}
	}

	// Other domains still rely on this client, so keep it running.
	if remaining > 0 {
		return nil
	}

	transport.CloseIdleConnections()
	if err := client.Stop(ctx); err != nil {
		n.logger.WithFields(log.Fields{
			"account_id": accountID,
		}).WithError(err).Warn("failed to stop netbird client")
	}
	return nil
}
// RoundTrip implements http.RoundTripper. The account ID is read from the
// request context and selects the embedded client whose transport carries the
// request to the backend.
//
// It fails with ErrNoAccountID when the context has no account ID, with a
// wrapped ErrNoPeerConnection when the account has no client, with
// ErrTooManyInflight when the backend's in-flight limit is exhausted, and with
// a wrapped ErrClientStartFailed when the client cannot be started within 30
// seconds. Errors from the underlying transport are returned unchanged.
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
	reqCtx := req.Context()

	accountID := AccountIDFromContext(reqCtx)
	if accountID == "" {
		return nil, ErrNoAccountID
	}

	// Grab what we need under the read lock and release it right away so the
	// potentially slow round trip does not hold up other requests.
	var (
		client    *embed.Client
		transport *http.Transport
	)
	n.clientsMux.RLock()
	entry, found := n.clients[accountID]
	if found {
		client, transport = entry.client, entry.transport
	}
	n.clientsMux.RUnlock()
	if !found {
		return nil, fmt.Errorf("%w for account: %s", ErrNoPeerConnection, accountID)
	}

	release, ok := entry.acquireInflight(req.URL.Host)
	defer release()
	if !ok {
		return nil, ErrTooManyInflight
	}

	// Make sure the client is running. An "already started" error is the
	// normal case and is ignored; anything else, including a timeout, makes
	// this request unprocessable.
	startCtx, cancel := context.WithTimeout(reqCtx, 30*time.Second)
	defer cancel()
	startErr := client.Start(startCtx)
	if startErr != nil && !errors.Is(startErr, embed.ErrClientAlreadyStarted) {
		return nil, fmt.Errorf("%w: %w", ErrClientStartFailed, startErr)
	}

	began := time.Now()
	resp, err := transport.RoundTrip(req)
	elapsed := time.Since(began).Truncate(time.Millisecond)

	if err != nil {
		n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s duration=%s err=%v",
			req.Method, req.Host, req.URL.String(), accountID, elapsed, err)
		return nil, err
	}

	n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s status=%d duration=%s",
		req.Method, req.Host, req.URL.String(), accountID, resp.StatusCode, elapsed)
	return resp, nil
}
// StopAll closes idle connections and stops the client of every account, then
// empties the client map. The clients lock is held for the whole operation.
// Failures to stop individual clients are logged and collected; the combined
// error (or nil when all stops succeeded) is returned.
func (n *NetBird) StopAll(ctx context.Context) error {
	n.clientsMux.Lock()
	defer n.clientsMux.Unlock()

	var stopErrs *multierror.Error
	for id, e := range n.clients {
		e.transport.CloseIdleConnections()

		err := e.client.Stop(ctx)
		if err == nil {
			continue
		}
		n.logger.WithField("account_id", id).WithError(err).Warn("failed to stop netbird client during shutdown")
		stopErrs = multierror.Append(stopErrs, err)
	}

	maps.Clear(n.clients)
	return nberrors.FormatErrorOrNil(stopErrs)
}
// HasClient reports whether a client entry exists for accountID.
func (n *NetBird) HasClient(accountID types.AccountID) bool {
	n.clientsMux.RLock()
	_, found := n.clients[accountID]
	n.clientsMux.RUnlock()
	return found
}
// DomainCount reports how many domains are registered for accountID.
// An account without a client has zero domains.
func (n *NetBird) DomainCount(accountID types.AccountID) int {
	n.clientsMux.RLock()
	defer n.clientsMux.RUnlock()

	if entry, ok := n.clients[accountID]; ok {
		return len(entry.domains)
	}
	return 0
}
// ClientCount reports how many accounts currently have a client entry.
func (n *NetBird) ClientCount() int {
	n.clientsMux.RLock()
	total := len(n.clients)
	n.clientsMux.RUnlock()
	return total
}
// GetClient looks up the embedded client registered for accountID. The second
// result is false, and the client nil, when the account has no entry.
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
	n.clientsMux.RLock()
	defer n.clientsMux.RUnlock()

	if entry, ok := n.clients[accountID]; ok {
		return entry.client, true
	}
	return nil, false
}
// ListClientsForDebug returns a snapshot describing every client entry, keyed
// by account ID, for debugging. The domain lists are copies and are in no
// particular order.
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
	n.clientsMux.RLock()
	defer n.clientsMux.RUnlock()

	snapshot := make(map[types.AccountID]ClientDebugInfo, len(n.clients))
	for id, e := range n.clients {
		names := make(domain.List, 0, len(e.domains))
		for name := range e.domains {
			names = append(names, name)
		}

		info := ClientDebugInfo{
			AccountID:   id,
			DomainCount: len(e.domains),
			Domains:     names,
			HasClient:   e.client != nil,
			CreatedAt:   e.createdAt,
		}
		snapshot[id] = info
	}
	return snapshot
}
// ListClientsForStartup collects the embedded client of every account for
// health checks. Entries without a client are left out.
func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
	n.clientsMux.RLock()
	defer n.clientsMux.RUnlock()

	active := make(map[types.AccountID]*embed.Client)
	for id, e := range n.clients {
		if e.client == nil {
			continue
		}
		active[id] = e.client
	}
	return active
}
// NewNetBird constructs a NetBird transport with an empty client map.
//
// A nil logger is replaced by the logrus standard logger, and the transport
// settings are loaded with that logger. notifier may be nil, in which case no
// status notifications are sent.
//
// Set wgPort to 0 for a random OS-assigned port. A fixed port only works with
// single-account deployments; multiple accounts will fail to bind the same
// port.
func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
	effectiveLogger := logger
	if effectiveLogger == nil {
		effectiveLogger = log.StandardLogger()
	}

	nb := &NetBird{
		mgmtAddr:       mgmtAddr,
		proxyID:        proxyID,
		proxyAddr:      proxyAddr,
		wgPort:         wgPort,
		logger:         effectiveLogger,
		clients:        make(map[types.AccountID]*clientEntry),
		statusNotifier: notifier,
		mgmtClient:     mgmtClient,
		transportCfg:   loadTransportConfig(effectiveLogger),
	}
	return nb
}
// WithAccountID returns a copy of ctx that carries accountID, retrievable with
// AccountIDFromContext.
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
	key := accountIDContextKey{}
	return context.WithValue(ctx, key, accountID)
}
// AccountIDFromContext extracts the account ID stored by WithAccountID. It
// returns the empty ID when ctx carries no value under the account key or when
// the stored value is not a types.AccountID.
func AccountIDFromContext(ctx context.Context) types.AccountID {
	// A comma-ok assertion on a nil interface value simply reports false, so
	// the "missing" and "wrong type" cases share one branch.
	if id, ok := ctx.Value(accountIDContextKey{}).(types.AccountID); ok {
		return id
	}
	return ""
}

View File

@@ -0,0 +1,107 @@
package roundtrip
import (
"crypto/rand"
"math/big"
"sync"
"testing"
"time"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
)
// BenchmarkHasClient measures parallel HasClient lookups against a populated,
// otherwise idle client map. It is the baseline for comparison with the
// AddPeer contention benchmark.
func BenchmarkHasClient(b *testing.B) {
	// Knobs for dialling in:
	initialClientCount := 100 // Size of initial peer map to generate.

	nb := mockNetBird()

	// Choose which of the generated accounts will be looked up.
	pick, err := rand.Int(rand.Reader, big.NewInt(int64(initialClientCount)))
	if err != nil {
		b.Fatal(err)
	}
	targetIdx := pick.Int64()

	var target types.AccountID
	for i := 0; i < initialClientCount; i++ {
		id := types.AccountID(rand.Text())
		nb.clients[id] = &clientEntry{
			domains: map[domain.Domain]domainInfo{
				domain.Domain(rand.Text()): {serviceID: rand.Text()},
			},
			createdAt: time.Now(),
			started:   true,
		}
		if int64(i) == targetIdx {
			target = id
		}
	}

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			nb.HasClient(target)
		}
	})
	b.StopTimer()
}
// BenchmarkHasClientDuringAddPeer measures parallel HasClient lookups while a
// fixed number of background workers continuously call AddPeer with fresh
// random account IDs, i.e. under write contention on the client map.
//
// The workers are stopped and waited for before the benchmark returns, so they
// neither pile up across the repeated runs the benchmark harness performs nor
// log through b after the benchmark has completed.
func BenchmarkHasClientDuringAddPeer(b *testing.B) {
	// Knobs for dialling in:
	initialClientCount := 100 // Size of initial peer map to generate.
	addPeerWorkers := 5       // Number of workers to concurrently call AddPeer.

	nb := mockNetBird()

	// Add random client entries to the netbird instance.
	// We're trying to test map lock contention, so starting with
	// a populated map should help with this.
	// Pick a random one to target for retrieval later.
	var target types.AccountID
	targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(initialClientCount)))
	if err != nil {
		b.Fatal(err)
	}
	for i := range initialClientCount {
		id := types.AccountID(rand.Text())
		if int64(i) == targetIndex.Int64() {
			target = id
		}
		nb.clients[id] = &clientEntry{
			domains: map[domain.Domain]domainInfo{
				domain.Domain(rand.Text()): {
					serviceID: rand.Text(),
				},
			},
			createdAt: time.Now(),
			started:   true,
		}
	}

	// Launch workers that continuously call AddPeer with new random accountIDs
	// until done is closed.
	done := make(chan struct{})
	var wg sync.WaitGroup
	for range addPeerWorkers {
		wg.Go(func() {
			for {
				select {
				case <-done:
					return
				default:
				}
				if err := nb.AddPeer(b.Context(),
					types.AccountID(rand.Text()),
					domain.Domain(rand.Text()),
					rand.Text(),
					rand.Text()); err != nil {
					b.Log(err)
				}
			}
		})
	}

	// Benchmark calling HasClient during AddPeer contention.
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			nb.HasClient(target)
		}
	})
	b.StopTimer()

	// Shut the workers down outside the timed region.
	close(done)
	wg.Wait()
}

View File

@@ -0,0 +1,328 @@
package roundtrip
import (
"context"
"net/http"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
// mockMgmtClient is a managementClient stub whose peer creation always
// succeeds without contacting a server.
type mockMgmtClient struct{}

// CreateProxyPeer ignores its arguments and reports success.
func (*mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
	resp := &proto.CreateProxyPeerResponse{Success: true}
	return resp, nil
}
// mockStatusNotifier is a statusNotifier that records every notification it
// receives so tests can inspect them afterwards.
type mockStatusNotifier struct {
// mu guards statuses.
mu sync.Mutex
// statuses holds the recorded calls in the order they were received.
statuses []statusCall
}
// statusCall captures the arguments of a single NotifyStatus call.
type statusCall struct {
accountID string
serviceID string
domain string
connected bool
}
// NotifyStatus records the call and always succeeds.
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
	recorded := statusCall{
		accountID: accountID,
		serviceID: serviceID,
		domain:    domain,
		connected: connected,
	}

	m.mu.Lock()
	m.statuses = append(m.statuses, recorded)
	m.mu.Unlock()
	return nil
}
// calls returns a copy of the notifications recorded so far, in order. The
// result is never nil and is independent of the notifier's internal slice.
func (m *mockStatusNotifier) calls() []statusCall {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]statusCall, len(m.statuses))
	copy(snapshot, m.statuses)
	return snapshot
}
// mockNetBird builds a NetBird instance for tests: it points at an invalid
// management URL so no real connection can be made, uses the default logger,
// has no status notifier, and registers peers through mockMgmtClient.
func mockNetBird() *NetBird {
	mgmt := &mockMgmtClient{}
	return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, mgmt)
}
// TestNetBird_AddPeer_CreatesClientForNewAccount verifies that the first
// domain added for an account creates its client entry.
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	accountID := types.AccountID("account-1")

	// Nothing is registered up front.
	assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
	assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")

	// The management URL is invalid, so the client can never connect, but the
	// entry must be created regardless.
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1"))

	assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
	assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
}
// TestNetBird_AddPeer_ReuseClientForSameAccount verifies that further domains
// of the same account attach to the existing client.
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	accountID := types.AccountID("account-1")

	// First domain creates the client.
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1"))
	assert.Equal(t, 1, nb.DomainCount(accountID))

	// Second domain of the same account reuses it.
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2"))
	assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")

	// So does the third.
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3"))
	assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")

	// The account still has its client.
	assert.True(t, nb.HasClient(accountID))
}
// TestNetBird_AddPeer_SeparateClientsForDifferentAccounts verifies that each
// account gets its own client entry.
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	first := types.AccountID("account-1")
	second := types.AccountID("account-2")

	// One domain per account.
	require.NoError(t, nb.AddPeer(ctx, first, domain.Domain("domain1.test"), "setup-key-1", "proxy-1"))
	require.NoError(t, nb.AddPeer(ctx, second, domain.Domain("domain2.test"), "setup-key-2", "proxy-2"))

	// Each account owns a client with exactly one domain.
	assert.True(t, nb.HasClient(first), "account1 should have client")
	assert.True(t, nb.HasClient(second), "account2 should have client")
	assert.Equal(t, 1, nb.DomainCount(first), "account1 domain count should be 1")
	assert.Equal(t, 1, nb.DomainCount(second), "account2 domain count should be 1")
}
// TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain verifies that removing
// some, but not all, domains leaves the account's client in place.
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	accountID := types.AccountID("account-1")

	// Register three domains on one account.
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1"))
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2"))
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3"))
	assert.Equal(t, 3, nb.DomainCount(accountID))

	// Dropping the first domain keeps the client.
	require.NoError(t, nb.RemovePeer(ctx, accountID, "domain1.test"))
	assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
	assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")

	// Dropping the second domain still keeps it.
	require.NoError(t, nb.RemovePeer(ctx, accountID, "domain2.test"))
	assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
	assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
}
// TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved verifies that
// removing an account's only domain also removes its client entry.
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	accountID := types.AccountID("account-1")

	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1"))
	assert.True(t, nb.HasClient(accountID))

	// The result is deliberately ignored: Stop() may fail because the client
	// never actually connected, but the map entry must be dropped regardless.
	_ = nb.RemovePeer(ctx, accountID, "domain1.test")

	assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
	assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
}
// TestNetBird_RemovePeer_NonExistentAccountIsNoop verifies that removing a
// domain from an unknown account succeeds without doing anything.
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
	nb := mockNetBird()
	unknown := types.AccountID("nonexistent-account")

	got := nb.RemovePeer(context.Background(), unknown, "domain1.test")
	assert.NoError(t, got, "removing from non-existent account should not error")
}
// TestNetBird_RemovePeer_NonExistentDomainIsNoop verifies that removing a
// domain that was never registered leaves the account's state untouched.
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	accountID := types.AccountID("account-1")

	// One real domain is registered.
	require.NoError(t, nb.AddPeer(ctx, accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1"))

	// Removing a different, unknown domain must not touch it.
	require.NoError(t, nb.RemovePeer(ctx, accountID, domain.Domain("nonexistent.test")))

	assert.True(t, nb.HasClient(accountID))
	assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
}
// TestWithAccountID_AndAccountIDFromContext verifies that an account ID stored
// with WithAccountID is returned by AccountIDFromContext, and that a context
// without one yields the empty ID.
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
	base := context.Background()
	want := types.AccountID("test-account")

	// A plain context carries no account ID.
	assert.True(t, AccountIDFromContext(base) == "", "should be empty when not set")

	// After attaching the ID it comes back unchanged.
	withID := WithAccountID(base, want)
	got := AccountIDFromContext(withID)
	assert.Equal(t, want, got, "should retrieve the same account ID")
}
// TestAccountIDFromContext_ReturnsEmptyForWrongType verifies that a value of
// the wrong type stored under the account key is treated as absent.
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
	// A plain string is not a types.AccountID, even under the right key.
	parent := context.Background()
	ctx := context.WithValue(parent, accountIDContextKey{}, "wrong-type-string")

	got := AccountIDFromContext(ctx)
	assert.True(t, got == "", "should return empty for wrong type")
}
// TestNetBird_StopAll_StopsAllClients verifies that StopAll empties the client
// map regardless of whether the individual clients stop cleanly.
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	first := types.AccountID("account-1")
	second := types.AccountID("account-2")
	third := types.AccountID("account-3")

	// One domain for each of three accounts.
	require.NoError(t, nb.AddPeer(ctx, first, domain.Domain("domain1.test"), "key-1", "proxy-1"))
	require.NoError(t, nb.AddPeer(ctx, second, domain.Domain("domain2.test"), "key-2", "proxy-2"))
	require.NoError(t, nb.AddPeer(ctx, third, domain.Domain("domain3.test"), "key-3", "proxy-3"))
	assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")

	// The result is deliberately ignored: the clients never connected, so
	// stopping them may report errors, yet the map must still be cleared.
	_ = nb.StopAll(ctx)

	assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
	assert.False(t, nb.HasClient(first), "account1 should not have client")
	assert.False(t, nb.HasClient(second), "account2 should not have client")
	assert.False(t, nb.HasClient(third), "account3 should not have client")
}
// TestNetBird_ClientCount verifies that the client count follows the number of
// distinct accounts rather than the number of domains.
func TestNetBird_ClientCount(t *testing.T) {
	ctx := context.Background()
	nb := mockNetBird()
	assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")

	// Every new account adds one client.
	require.NoError(t, nb.AddPeer(ctx, types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1"))
	assert.Equal(t, 1, nb.ClientCount())

	require.NoError(t, nb.AddPeer(ctx, types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2"))
	assert.Equal(t, 2, nb.ClientCount())

	// Another domain on a known account does not.
	require.NoError(t, nb.AddPeer(ctx, types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b"))
	assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
}
// TestNetBird_RoundTrip_RequiresAccountIDInContext verifies that a request
// whose context lacks an account ID is rejected with ErrNoAccountID.
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
	nb := mockNetBird()

	// The request context is left without an account ID on purpose.
	req, reqErr := http.NewRequest("GET", "http://example.com/", nil)
	require.NoError(t, reqErr)

	_, rtErr := nb.RoundTrip(req) //nolint:bodyclose
	require.ErrorIs(t, rtErr, ErrNoAccountID)
}
// TestNetBird_RoundTrip_RequiresExistingClient verifies that a request for an
// account without a client fails with ErrNoPeerConnection and a message naming
// the problem.
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
	nb := mockNetBird()
	accountID := types.AccountID("nonexistent-account")

	// Create a request with account ID but no client exists.
	req, err := http.NewRequest("GET", "http://example.com/", nil)
	require.NoError(t, err)
	req = req.WithContext(WithAccountID(req.Context(), accountID))

	// RoundTrip should fail because no client for this account.
	_, err = nb.RoundTrip(req) //nolint:bodyclose // Error case, no response body

	// require (not assert) stops the test here on a nil error, so err.Error()
	// below can never be called on nil.
	require.ErrorIs(t, err, ErrNoPeerConnection)
	assert.Contains(t, err.Error(), "no peer connection found for account")
}
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
// Manually mark client as started to simulate background startup completing.
nb.clientsMux.Lock()
nb.clients[accountID].started = true
nb.clientsMux.Unlock()
// Add second domain — should notify immediately since client is already started.
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, string(accountID), calls[0].accountID)
assert.Equal(t, "svc-2", calls[0].serviceID)
assert.Equal(t, "domain2.test", calls[0].domain)
assert.True(t, calls[0].connected)
}
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
require.NoError(t, err)
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
require.NoError(t, err)
// Remove one domain — client stays, but disconnection notification fires.
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
require.NoError(t, err)
assert.True(t, nb.HasClient(accountID))
calls := notifier.calls()
require.Len(t, calls, 1)
assert.Equal(t, "domain1.test", calls[0].domain)
assert.False(t, calls[0].connected)
}

View File

@@ -0,0 +1,152 @@
package roundtrip
import (
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
// Names of the environment variables that override the backend HTTP
// transport defaults. Each one is read by loadTransportConfig.
const (
	// Connection pool sizing (non-negative integers).
	EnvMaxIdleConns        = "NB_PROXY_MAX_IDLE_CONNS"
	EnvMaxIdleConnsPerHost = "NB_PROXY_MAX_IDLE_CONNS_PER_HOST"
	EnvMaxConnsPerHost     = "NB_PROXY_MAX_CONNS_PER_HOST"

	// Timeouts (Go duration strings such as "30s").
	EnvIdleConnTimeout       = "NB_PROXY_IDLE_CONN_TIMEOUT"
	EnvTLSHandshakeTimeout   = "NB_PROXY_TLS_HANDSHAKE_TIMEOUT"
	EnvExpectContinueTimeout = "NB_PROXY_EXPECT_CONTINUE_TIMEOUT"
	EnvResponseHeaderTimeout = "NB_PROXY_RESPONSE_HEADER_TIMEOUT"

	// I/O buffer sizes (non-negative integers).
	EnvWriteBufferSize = "NB_PROXY_WRITE_BUFFER_SIZE"
	EnvReadBufferSize  = "NB_PROXY_READ_BUFFER_SIZE"

	// Compression toggle (boolean).
	EnvDisableCompression = "NB_PROXY_DISABLE_COMPRESSION"

	// Per-backend concurrent request limit (non-negative integer).
	EnvMaxInflight = "NB_PROXY_MAX_INFLIGHT"
)
// transportConfig collects the tunable settings of the per-account HTTP
// transport. Values start from defaultTransportConfig and may be overridden
// through the NB_PROXY_* environment variables in loadTransportConfig.
type transportConfig struct {
	// Connection pool limits.
	maxIdleConns, maxIdleConnsPerHost, maxConnsPerHost int

	// Timeouts applied to backend connections and requests.
	idleConnTimeout, tlsHandshakeTimeout, expectContinueTimeout, responseHeaderTimeout time.Duration

	// I/O buffer sizes.
	writeBufferSize, readBufferSize int

	// disableCompression is the compression toggle read from
	// NB_PROXY_DISABLE_COMPRESSION.
	disableCompression bool

	// maxInflight limits per-backend concurrent requests. 0 means unlimited.
	maxInflight int
}
// defaultTransportConfig returns the baseline transport settings used when no
// environment overrides are present. Fields not assigned here keep their zero
// value; for maxConnsPerHost that zero means "unlimited".
func defaultTransportConfig() transportConfig {
	var cfg transportConfig

	cfg.maxIdleConns = 100
	cfg.maxIdleConnsPerHost = 100
	cfg.maxConnsPerHost = 0 // unlimited

	cfg.idleConnTimeout = 90 * time.Second
	cfg.tlsHandshakeTimeout = 10 * time.Second
	cfg.expectContinueTimeout = 1 * time.Second

	return cfg
}
// loadTransportConfig builds the backend transport configuration. It starts
// from defaultTransportConfig and overlays every NB_PROXY_* environment
// variable that is set and accepted by the env helpers; variables that are
// unset or rejected leave the default untouched. The resulting configuration
// is logged at debug level on logger and then returned.
func loadTransportConfig(logger *log.Logger) transportConfig {
	cfg := defaultTransportConfig()

	// Setters that write through to a config field only when the helper
	// reports a usable value.
	setInt := func(key string, dst *int) {
		if v, ok := envInt(key, logger); ok {
			*dst = v
		}
	}
	setDuration := func(key string, dst *time.Duration) {
		if v, ok := envDuration(key, logger); ok {
			*dst = v
		}
	}
	setBool := func(key string, dst *bool) {
		if v, ok := envBool(key, logger); ok {
			*dst = v
		}
	}

	// Variables are consulted in a fixed order, so any parse warnings are
	// emitted in a stable sequence.
	setInt(EnvMaxIdleConns, &cfg.maxIdleConns)
	setInt(EnvMaxIdleConnsPerHost, &cfg.maxIdleConnsPerHost)
	setInt(EnvMaxConnsPerHost, &cfg.maxConnsPerHost)
	setDuration(EnvIdleConnTimeout, &cfg.idleConnTimeout)
	setDuration(EnvTLSHandshakeTimeout, &cfg.tlsHandshakeTimeout)
	setDuration(EnvExpectContinueTimeout, &cfg.expectContinueTimeout)
	setDuration(EnvResponseHeaderTimeout, &cfg.responseHeaderTimeout)
	setInt(EnvWriteBufferSize, &cfg.writeBufferSize)
	setInt(EnvReadBufferSize, &cfg.readBufferSize)
	setBool(EnvDisableCompression, &cfg.disableCompression)
	setInt(EnvMaxInflight, &cfg.maxInflight)

	effective := log.Fields{
		"max_idle_conns":          cfg.maxIdleConns,
		"max_idle_conns_per_host": cfg.maxIdleConnsPerHost,
		"max_conns_per_host":      cfg.maxConnsPerHost,
		"idle_conn_timeout":       cfg.idleConnTimeout,
		"tls_handshake_timeout":   cfg.tlsHandshakeTimeout,
		"expect_continue_timeout": cfg.expectContinueTimeout,
		"response_header_timeout": cfg.responseHeaderTimeout,
		"write_buffer_size":       cfg.writeBufferSize,
		"read_buffer_size":        cfg.readBufferSize,
		"disable_compression":     cfg.disableCompression,
		"max_inflight":            cfg.maxInflight,
	}
	logger.WithFields(effective).Debug("backend transport configuration")

	return cfg
}
// envInt reads the environment variable key as a non-negative integer.
// It returns (value, true) on success. It returns (0, false) when the variable
// is unset or empty, and also when it cannot be parsed or is negative; in the
// latter two cases a warning is written to logger.
func envInt(key string, logger *log.Logger) (int, bool) {
	raw := os.Getenv(key)
	if raw == "" {
		return 0, false
	}

	parsed, parseErr := strconv.Atoi(raw)
	switch {
	case parseErr != nil:
		logger.Warnf("failed to parse %s=%q as int: %v", key, raw, parseErr)
		return 0, false
	case parsed < 0:
		logger.Warnf("ignoring negative value for %s=%d", key, parsed)
		return 0, false
	}

	return parsed, true
}
// envDuration reads the environment variable key as a non-negative
// time.Duration in time.ParseDuration syntax. It returns (value, true) on
// success. It returns (0, false) when the variable is unset or empty, and also
// when it cannot be parsed or is negative; in the latter two cases a warning
// is written to logger.
func envDuration(key string, logger *log.Logger) (time.Duration, bool) {
	raw := os.Getenv(key)
	if raw == "" {
		return 0, false
	}

	parsed, parseErr := time.ParseDuration(raw)
	switch {
	case parseErr != nil:
		logger.Warnf("failed to parse %s=%q as duration: %v", key, raw, parseErr)
		return 0, false
	case parsed < 0:
		logger.Warnf("ignoring negative value for %s=%s", key, parsed)
		return 0, false
	}

	return parsed, true
}
// envBool reads the environment variable key as a boolean in
// strconv.ParseBool syntax. It returns (value, true) on success. It returns
// (false, false) when the variable is unset or empty, and also when it cannot
// be parsed, in which case a warning is written to logger.
func envBool(key string, logger *log.Logger) (bool, bool) {
	raw := os.Getenv(key)
	if raw == "" {
		return false, false
	}

	parsed, parseErr := strconv.ParseBool(raw)
	if parseErr == nil {
		return parsed, true
	}

	logger.Warnf("failed to parse %s=%q as bool: %v", key, raw, parseErr)
	return false, false
}

View File

@@ -0,0 +1,5 @@
// Package types defines common types used across the proxy package.
package types
// AccountID represents a unique identifier for a NetBird account.
// It is a defined type with underlying type string, so string variables must
// be converted explicitly (AccountID(s)) before being used as an account ID.
type AccountID string