mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy --------- Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk> Co-authored-by: mlsmaycon <mlsmaycon@gmail.com> Co-authored-by: Eduard Gert <kontakt@eduardgert.de> Co-authored-by: Viktor Liu <viktor@netbird.io> Co-authored-by: Diego Noguês <diego.sure@gmail.com> Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
105
proxy/internal/accesslog/logger.go
Normal file
105
proxy/internal/accesslog/logger.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type gRPCClient interface {
|
||||
SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error)
|
||||
}
|
||||
|
||||
// Logger sends access log entries to the management server via gRPC.
|
||||
type Logger struct {
|
||||
client gRPCClient
|
||||
logger *log.Logger
|
||||
trustedProxies []netip.Prefix
|
||||
}
|
||||
|
||||
// NewLogger creates a new access log Logger. The trustedProxies parameter
|
||||
// configures which upstream proxy IP ranges are trusted for extracting
|
||||
// the real client IP from X-Forwarded-For headers.
|
||||
func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Prefix) *Logger {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &Logger{
|
||||
client: client,
|
||||
logger: logger,
|
||||
trustedProxies: trustedProxies,
|
||||
}
|
||||
}
|
||||
|
||||
// logEntry is the in-process representation of one access log record. Its
// fields are copied one-to-one into proto.AccessLog when the entry is sent.
type logEntry struct {
	// ID is the per-request identifier, sent as LogId.
	ID string
	// AccountID identifies the account the request was served for.
	AccountID string
	// ServiceId identifies the proxied service that handled the request.
	ServiceId string
	// Host is the request host with any port removed where possible.
	Host string
	// Path is the request URL path.
	Path string
	// DurationMs is the handling time in milliseconds.
	DurationMs int64
	// Method is the HTTP request method.
	Method string
	// ResponseCode is the HTTP status code of the response.
	ResponseCode int32
	// SourceIp is the resolved client address.
	SourceIp string
	// AuthMechanism names the authentication method used, if any.
	AuthMechanism string
	// UserId is the authenticated user; it is blanked before sending unless
	// AuthMechanism is the OIDC method.
	UserId string
	// AuthSuccess is false when the response was 401 or 403.
	AuthSuccess bool
}
|
||||
|
||||
func (l *Logger) log(ctx context.Context, entry logEntry) {
|
||||
// Fire off the log request in a separate routine.
|
||||
// This increases the possibility of losing a log message
|
||||
// (although it should still get logged in the event of an error),
|
||||
// but it will reduce latency returning the request in the
|
||||
// middleware.
|
||||
// There is also a chance that log messages will arrive at
|
||||
// the server out of order; however, the timestamp should
|
||||
// allow for resolving that on the server.
|
||||
now := timestamppb.Now() // Grab the timestamp before launching the goroutine to try to prevent weird timing issues. This is probably unnecessary.
|
||||
go func() {
|
||||
logCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if entry.AuthMechanism != auth.MethodOIDC.String() {
|
||||
entry.UserId = ""
|
||||
}
|
||||
if _, err := l.client.SendAccessLog(logCtx, &proto.SendAccessLogRequest{
|
||||
Log: &proto.AccessLog{
|
||||
LogId: entry.ID,
|
||||
AccountId: entry.AccountID,
|
||||
Timestamp: now,
|
||||
ServiceId: entry.ServiceId,
|
||||
Host: entry.Host,
|
||||
Path: entry.Path,
|
||||
DurationMs: entry.DurationMs,
|
||||
Method: entry.Method,
|
||||
ResponseCode: entry.ResponseCode,
|
||||
SourceIp: entry.SourceIp,
|
||||
AuthMechanism: entry.AuthMechanism,
|
||||
UserId: entry.UserId,
|
||||
AuthSuccess: entry.AuthSuccess,
|
||||
},
|
||||
}); err != nil {
|
||||
// If it fails to send on the gRPC connection, then at least log it to the error log.
|
||||
l.logger.WithFields(log.Fields{
|
||||
"service_id": entry.ServiceId,
|
||||
"host": entry.Host,
|
||||
"path": entry.Path,
|
||||
"duration": entry.DurationMs,
|
||||
"method": entry.Method,
|
||||
"response_code": entry.ResponseCode,
|
||||
"source_ip": entry.SourceIp,
|
||||
"auth_mechanism": entry.AuthMechanism,
|
||||
"user_id": entry.UserId,
|
||||
"auth_success": entry.AuthSuccess,
|
||||
"error": err,
|
||||
}).Error("Error sending access log on gRPC connection")
|
||||
}
|
||||
}()
|
||||
}
|
||||
74
proxy/internal/accesslog/middleware.go
Normal file
74
proxy/internal/accesslog/middleware.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip logging for internal proxy assets (CSS, JS, etc.)
|
||||
if strings.HasPrefix(r.URL.Path, web.PathPrefix+"/") {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate request ID early so it can be used by error pages and log correlation.
|
||||
requestID := xid.New().String()
|
||||
|
||||
l.logger.Debugf("request: request_id=%s method=%s host=%s path=%s", requestID, r.Method, r.Host, r.URL.Path)
|
||||
|
||||
// Use a response writer wrapper so we can access the status code later.
|
||||
sw := &statusWriter{
|
||||
w: w,
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Resolve the source IP using trusted proxy configuration before passing
|
||||
// the request on, as the proxy will modify forwarding headers.
|
||||
sourceIp := extractSourceIP(r, l.trustedProxies)
|
||||
|
||||
// Create a mutable struct to capture data from downstream handlers.
|
||||
// We pass a pointer in the context - the pointer itself flows down immutably,
|
||||
// but the struct it points to can be mutated by inner handlers.
|
||||
capturedData := &proxy.CapturedData{RequestID: requestID}
|
||||
capturedData.SetClientIP(sourceIp)
|
||||
ctx := proxy.WithCapturedData(r.Context(), capturedData)
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(sw, r.WithContext(ctx))
|
||||
duration := time.Since(start)
|
||||
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
// Fallback to just using the full host value.
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
entry := logEntry{
|
||||
ID: requestID,
|
||||
ServiceId: capturedData.GetServiceId(),
|
||||
AccountID: string(capturedData.GetAccountId()),
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(sw.status),
|
||||
SourceIp: sourceIp,
|
||||
AuthMechanism: capturedData.GetAuthMethod(),
|
||||
UserId: capturedData.GetUserID(),
|
||||
AuthSuccess: sw.status != http.StatusUnauthorized && sw.status != http.StatusForbidden,
|
||||
}
|
||||
l.logger.Debugf("response: request_id=%s method=%s host=%s path=%s status=%d duration=%dms source=%s origin=%s service=%s account=%s",
|
||||
requestID, r.Method, host, r.URL.Path, sw.status, duration.Milliseconds(), sourceIp, capturedData.GetOrigin(), capturedData.GetServiceId(), capturedData.GetAccountId())
|
||||
|
||||
l.log(r.Context(), entry)
|
||||
})
|
||||
}
|
||||
16
proxy/internal/accesslog/requestip.go
Normal file
16
proxy/internal/accesslog/requestip.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
)
|
||||
|
||||
// extractSourceIP resolves the real client IP from the request using trusted
|
||||
// proxy configuration. When trustedProxies is non-empty and the direct
|
||||
// connection is from a trusted source, it walks X-Forwarded-For right-to-left
|
||||
// skipping trusted IPs. Otherwise it returns RemoteAddr directly.
|
||||
func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string {
|
||||
return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies)
|
||||
}
|
||||
26
proxy/internal/accesslog/statuswriter.go
Normal file
26
proxy/internal/accesslog/statuswriter.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// statusWriter wraps an http.ResponseWriter and remembers the status code
// that was committed for the response so it can be read after the handler
// returns.
//
// Only the first final status is recorded: net/http ignores later
// WriteHeader calls, and a Write without a prior WriteHeader commits an
// implicit 200. Informational (1xx) codes other than 101 are passed through
// without being treated as the final status.
type statusWriter struct {
	// w is the wrapped writer that all calls are delegated to.
	w http.ResponseWriter
	// status is the recorded response code; callers initialise it to the
	// code assumed when the handler never writes a header.
	status int
	// wroteHeader reports whether a final header has been committed, either
	// explicitly through WriteHeader or implicitly through Write.
	wroteHeader bool
}

// Header returns the header map of the wrapped writer.
func (w *statusWriter) Header() http.Header {
	return w.w.Header()
}

// Write forwards data to the wrapped writer. If no header has been written
// yet, the write commits an implicit 200 OK, which is recorded.
func (w *statusWriter) Write(data []byte) (int, error) {
	if !w.wroteHeader {
		w.status = http.StatusOK
		w.wroteHeader = true
	}
	return w.w.Write(data)
}

// WriteHeader records status if it is the first final status code and then
// forwards the call to the wrapped writer.
func (w *statusWriter) WriteHeader(status int) {
	// Mirrors net/http: 1xx responses (except 101) may precede the final
	// header and must not be mistaken for it.
	informational := status >= 100 && status <= 199 && status != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.status = status
		w.wroteHeader = true
	}
	w.w.WriteHeader(status)
}

// Unwrap returns the wrapped writer so that http.ResponseController can
// reach optional capabilities of the original writer (flushing, hijacking,
// deadlines) that this wrapper does not re-implement.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.w
}
|
||||
102
proxy/internal/acme/locker.go
Normal file
102
proxy/internal/acme/locker.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/flock"
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
)
|
||||
|
||||
// certLocker is the mutual-exclusion abstraction used around certificate
// operations so that replicas do not work on the same domain at once.
// Implementations must be safe for concurrent use by multiple goroutines.
type certLocker interface {
	// Lock takes an exclusive lock for domain. It blocks until the lock is
	// held, ctx is cancelled, or an unrecoverable error occurs. On success
	// the returned unlock function releases the lock; callers must invoke
	// it exactly once when the critical section ends.
	Lock(ctx context.Context, domain string) (unlock func(), err error)
}
|
||||
|
||||
// CertLockMethod selects the mechanism used to coordinate ACME certificate
// locks between replicas.
type CertLockMethod string

// Supported lock methods.
const (
	// CertLockAuto picks k8s-lease when running inside a Kubernetes pod and
	// flock otherwise.
	CertLockAuto CertLockMethod = "auto"
	// CertLockFlock coordinates through advisory file locks (flock(2)).
	CertLockFlock CertLockMethod = "flock"
	// CertLockK8sLease coordinates through Kubernetes coordination Leases.
	CertLockK8sLease CertLockMethod = "k8s-lease"
)
|
||||
|
||||
func newCertLocker(method CertLockMethod, certDir string, logger *log.Logger) certLocker {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
if method == "" || method == CertLockAuto {
|
||||
if k8s.InCluster() {
|
||||
method = CertLockK8sLease
|
||||
} else {
|
||||
method = CertLockFlock
|
||||
}
|
||||
logger.Infof("auto-detected cert lock method: %s", method)
|
||||
}
|
||||
|
||||
switch method {
|
||||
case CertLockK8sLease:
|
||||
locker, err := newK8sLeaseLocker(logger)
|
||||
if err != nil {
|
||||
logger.Warnf("create k8s lease locker, falling back to flock: %v", err)
|
||||
return newFlockLocker(certDir, logger)
|
||||
}
|
||||
logger.Infof("using k8s lease locker in namespace %s", locker.client.Namespace())
|
||||
return locker
|
||||
default:
|
||||
logger.Infof("using flock cert locker in %s", certDir)
|
||||
return newFlockLocker(certDir, logger)
|
||||
}
|
||||
}
|
||||
|
||||
type flockLocker struct {
|
||||
certDir string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &flockLocker{certDir: certDir, logger: logger}
|
||||
}
|
||||
|
||||
// Lock acquires an advisory file lock for the given domain.
|
||||
func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
lockPath := filepath.Join(l.certDir, domain+".lock")
|
||||
lockFile, err := flock.Lock(ctx, lockPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// nil lockFile means locking is not supported (non-unix).
|
||||
if lockFile == nil {
|
||||
return func() { /* no-op: locking unsupported on this platform */ }, nil
|
||||
}
|
||||
|
||||
return func() {
|
||||
if err := flock.Unlock(lockFile); err != nil {
|
||||
l.logger.Debugf("release cert lock for domain %q: %v", domain, err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// noopLocker is a certLocker that performs no locking at all.
type noopLocker struct{}

// Lock succeeds immediately and returns an unlock function that does nothing.
func (noopLocker) Lock(context.Context, string) (func(), error) {
	release := func() { /* no-op: locker disabled */ }
	return release, nil
}
|
||||
197
proxy/internal/acme/locker_k8s.go
Normal file
197
proxy/internal/acme/locker_k8s.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
)
|
||||
|
||||
const (
	// leaseDurationSec is the TTL, in seconds, written into each Kubernetes
	// Lease. When a holder dies without releasing, other replicas have to
	// wait this long before they may take the lease over. The value is
	// deliberately generous: the worst case is two replicas both issuing an
	// ACME request for the same domain, which is harmless (the CA
	// deduplicates and the cache converges).
	leaseDurationSec = 300
	// retryBaseBackoff is the initial wait between acquisition attempts.
	retryBaseBackoff = 500 * time.Millisecond
	// retryMaxBackoff caps the exponentially growing wait between attempts.
	retryMaxBackoff = 10 * time.Second
)
|
||||
|
||||
type k8sLeaseLocker struct {
|
||||
client *k8s.LeaseClient
|
||||
identity string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func newK8sLeaseLocker(logger *log.Logger) (*k8sLeaseLocker, error) {
|
||||
client, err := k8s.NewLeaseClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create k8s lease client: %w", err)
|
||||
}
|
||||
|
||||
identity, err := os.Hostname()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get hostname: %w", err)
|
||||
}
|
||||
|
||||
return &k8sLeaseLocker{
|
||||
client: client,
|
||||
identity: identity,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lock acquires a Kubernetes Lease for the given domain using optimistic
|
||||
// concurrency. It retries with exponential backoff until the lease is
|
||||
// acquired or the context is cancelled.
|
||||
func (l *k8sLeaseLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
leaseName := k8s.LeaseNameForDomain(domain)
|
||||
backoff := retryBaseBackoff
|
||||
|
||||
for {
|
||||
acquired, err := l.tryAcquire(ctx, leaseName, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acquire lease %s for %q: %w", leaseName, domain, err)
|
||||
}
|
||||
if acquired {
|
||||
l.logger.Debugf("k8s lease %s acquired for domain %q", leaseName, domain)
|
||||
return l.unlockFunc(leaseName, domain), nil
|
||||
}
|
||||
|
||||
l.logger.Debugf("k8s lease %s held by another replica, retrying in %s", leaseName, backoff)
|
||||
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
backoff *= 2
|
||||
if backoff > retryMaxBackoff {
|
||||
backoff = retryMaxBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryAcquire attempts to create or take over a Lease. Returns (true, nil)
|
||||
// on success, (false, nil) if the lease is held and not stale, or an error.
|
||||
func (l *k8sLeaseLocker) tryAcquire(ctx context.Context, name, domain string) (bool, error) {
|
||||
existing, err := l.client.Get(ctx, name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
now := k8s.MicroTime{Time: time.Now().UTC()}
|
||||
dur := int32(leaseDurationSec)
|
||||
|
||||
if existing == nil {
|
||||
lease := &k8s.Lease{
|
||||
Metadata: k8s.LeaseMetadata{
|
||||
Name: name,
|
||||
Annotations: map[string]string{
|
||||
"netbird.io/domain": domain,
|
||||
},
|
||||
},
|
||||
Spec: k8s.LeaseSpec{
|
||||
HolderIdentity: &l.identity,
|
||||
LeaseDurationSeconds: &dur,
|
||||
AcquireTime: &now,
|
||||
RenewTime: &now,
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := l.client.Create(ctx, lease); errors.Is(err, k8s.ErrConflict) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if !l.canTakeover(existing) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
existing.Spec.HolderIdentity = &l.identity
|
||||
existing.Spec.LeaseDurationSeconds = &dur
|
||||
existing.Spec.AcquireTime = &now
|
||||
existing.Spec.RenewTime = &now
|
||||
|
||||
if _, err := l.client.Update(ctx, existing); errors.Is(err, k8s.ErrConflict) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// canTakeover returns true if the lease is free (no holder) or stale
|
||||
// (renewTime + leaseDuration has passed).
|
||||
func (l *k8sLeaseLocker) canTakeover(lease *k8s.Lease) bool {
|
||||
holder := lease.Spec.HolderIdentity
|
||||
if holder == nil || *holder == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// We already hold it (e.g. from a previous crashed attempt).
|
||||
if *holder == l.identity {
|
||||
return true
|
||||
}
|
||||
|
||||
if lease.Spec.RenewTime == nil || lease.Spec.LeaseDurationSeconds == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
expiry := lease.Spec.RenewTime.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second)
|
||||
if time.Now().After(expiry) {
|
||||
l.logger.Infof("k8s lease %s held by %q is stale (expired %s ago), taking over",
|
||||
lease.Metadata.Name, *holder, time.Since(expiry).Round(time.Second))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// unlockFunc returns a closure that releases the lease by clearing the holder.
|
||||
func (l *k8sLeaseLocker) unlockFunc(name, domain string) func() {
|
||||
return func() {
|
||||
// Use a fresh context: the parent may already be cancelled.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Re-GET to get current resourceVersion (ours may be stale if
|
||||
// the lock was held for a long time and something updated it).
|
||||
current, err := l.client.Get(ctx, name)
|
||||
if err != nil {
|
||||
l.logger.Debugf("release k8s lease %s for %q: get: %v", name, domain, err)
|
||||
return
|
||||
}
|
||||
if current == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Only clear if we're still the holder.
|
||||
if current.Spec.HolderIdentity == nil || *current.Spec.HolderIdentity != l.identity {
|
||||
l.logger.Debugf("k8s lease %s for %q: holder changed to %v, skip release",
|
||||
name, domain, current.Spec.HolderIdentity)
|
||||
return
|
||||
}
|
||||
|
||||
empty := ""
|
||||
current.Spec.HolderIdentity = &empty
|
||||
current.Spec.AcquireTime = nil
|
||||
current.Spec.RenewTime = nil
|
||||
|
||||
if _, err := l.client.Update(ctx, current); err != nil {
|
||||
l.logger.Debugf("release k8s lease %s for %q: update: %v", name, domain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
65
proxy/internal/acme/locker_test.go
Normal file
65
proxy/internal/acme/locker_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFlockLockerRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
locker := newFlockLocker(dir, nil)
|
||||
|
||||
unlock, err := locker.Lock(context.Background(), "example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, unlock)
|
||||
|
||||
// Lock file should exist.
|
||||
assert.FileExists(t, filepath.Join(dir, "example.com.lock"))
|
||||
|
||||
unlock()
|
||||
}
|
||||
|
||||
func TestNoopLocker(t *testing.T) {
|
||||
locker := noopLocker{}
|
||||
unlock, err := locker.Lock(context.Background(), "example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, unlock)
|
||||
unlock()
|
||||
}
|
||||
|
||||
func TestNewCertLockerDefaultsToFlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// t.Setenv registers cleanup to restore the original value.
|
||||
// os.Unsetenv is needed because the production code uses LookupEnv,
|
||||
// which distinguishes "empty" from "not set".
|
||||
t.Setenv("KUBERNETES_SERVICE_HOST", "")
|
||||
os.Unsetenv("KUBERNETES_SERVICE_HOST")
|
||||
locker := newCertLocker(CertLockAuto, dir, nil)
|
||||
|
||||
_, ok := locker.(*flockLocker)
|
||||
assert.True(t, ok, "auto without k8s env should select flockLocker")
|
||||
}
|
||||
|
||||
func TestNewCertLockerExplicitFlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
locker := newCertLocker(CertLockFlock, dir, nil)
|
||||
|
||||
_, ok := locker.(*flockLocker)
|
||||
assert.True(t, ok, "explicit flock should select flockLocker")
|
||||
}
|
||||
|
||||
func TestNewCertLockerK8sFallsBackToFlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// k8s-lease without SA files should fall back to flock.
|
||||
locker := newCertLocker(CertLockK8sLease, dir, nil)
|
||||
|
||||
_, ok := locker.(*flockLocker)
|
||||
assert.True(t, ok, "k8s-lease without SA should fall back to flockLocker")
|
||||
}
|
||||
336
proxy/internal/acme/manager.go
Normal file
336
proxy/internal/acme/manager.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// oidSCTList identifies the X.509 extension that carries the embedded SCT
// list (1.3.6.1.4.1.11129.2.4.2).
var oidSCTList = asn1.ObjectIdentifier{
	1, 3, 6, 1, 4, 1, 11129, 2, 4, 2,
}
|
||||
|
||||
// certificateNotifier is told when a certificate for a domain has been
// obtained, together with the account and service the domain belongs to.
type certificateNotifier interface {
	NotifyCertificateIssued(ctx context.Context, accountID, serviceID, domain string) error
}
|
||||
|
||||
// domainState is the certificate prefetch status of a registered domain.
type domainState int

const (
	// domainPending: the prefetch has not finished yet. It is the zero value.
	domainPending domainState = iota
	// domainReady: a certificate was obtained.
	domainReady
	// domainFailed: the prefetch ended with an error.
	domainFailed
)
|
||||
|
||||
type domainInfo struct {
|
||||
accountID string
|
||||
serviceID string
|
||||
state domainState
|
||||
err string
|
||||
}
|
||||
|
||||
// Manager wraps autocert.Manager with domain tracking and cross-replica
|
||||
// coordination via a pluggable locking strategy. The locker prevents
|
||||
// duplicate ACME requests when multiple replicas share a certificate cache.
|
||||
type Manager struct {
|
||||
*autocert.Manager
|
||||
|
||||
certDir string
|
||||
locker certLocker
|
||||
mu sync.RWMutex
|
||||
domains map[domain.Domain]*domainInfo
|
||||
|
||||
certNotifier certificateNotifier
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewManager creates a new ACME certificate manager. The certDir is used
|
||||
// for caching certificates. The lockMethod controls cross-replica
|
||||
// coordination strategy (see CertLockMethod constants).
|
||||
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
mgr := &Manager{
|
||||
certDir: certDir,
|
||||
locker: newCertLocker(lockMethod, certDir, logger),
|
||||
domains: make(map[domain.Domain]*domainInfo),
|
||||
certNotifier: notifier,
|
||||
logger: logger,
|
||||
}
|
||||
mgr.Manager = &autocert.Manager{
|
||||
Prompt: autocert.AcceptTOS,
|
||||
HostPolicy: mgr.hostPolicy,
|
||||
Cache: autocert.DirCache(certDir),
|
||||
Client: &acme.Client{
|
||||
DirectoryURL: acmeURL,
|
||||
},
|
||||
}
|
||||
return mgr
|
||||
}
|
||||
|
||||
func (mgr *Manager) hostPolicy(_ context.Context, host string) error {
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
mgr.mu.RLock()
|
||||
_, exists := mgr.domains[domain.Domain(host)]
|
||||
mgr.mu.RUnlock()
|
||||
if !exists {
|
||||
return fmt.Errorf("unknown domain %q", host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDomain registers a domain for ACME certificate prefetching.
|
||||
func (mgr *Manager) AddDomain(d domain.Domain, accountID, serviceID string) {
|
||||
mgr.mu.Lock()
|
||||
mgr.domains[d] = &domainInfo{
|
||||
accountID: accountID,
|
||||
serviceID: serviceID,
|
||||
state: domainPending,
|
||||
}
|
||||
mgr.mu.Unlock()
|
||||
|
||||
go mgr.prefetchCertificate(d)
|
||||
}
|
||||
|
||||
// prefetchCertificate proactively triggers certificate generation for a domain.
|
||||
// It acquires a distributed lock to prevent multiple replicas from issuing
|
||||
// duplicate ACME requests. The second replica will block until the first
|
||||
// finishes, then find the certificate in the cache.
|
||||
func (mgr *Manager) prefetchCertificate(d domain.Domain) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
name := d.PunycodeString()
|
||||
|
||||
mgr.logger.Infof("acquiring cert lock for domain %q", name)
|
||||
lockStart := time.Now()
|
||||
unlock, err := mgr.locker.Lock(ctx, name)
|
||||
if err != nil {
|
||||
mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", name, err)
|
||||
} else {
|
||||
mgr.logger.Infof("acquired cert lock for domain %q in %s", name, time.Since(lockStart))
|
||||
defer unlock()
|
||||
}
|
||||
|
||||
hello := &tls.ClientHelloInfo{
|
||||
ServerName: name,
|
||||
Conn: &dummyConn{ctx: ctx},
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
cert, err := mgr.GetCertificate(hello)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
mgr.logger.Warnf("prefetch certificate for domain %q: %v", name, err)
|
||||
mgr.setDomainState(d, domainFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
mgr.setDomainState(d, domainReady, "")
|
||||
|
||||
now := time.Now()
|
||||
if cert != nil && cert.Leaf != nil {
|
||||
leaf := cert.Leaf
|
||||
mgr.logger.Infof("certificate for domain %q ready in %s: serial=%s SANs=%v notBefore=%s, notAfter=%s, now=%s",
|
||||
name, elapsed.Round(time.Millisecond),
|
||||
leaf.SerialNumber.Text(16),
|
||||
leaf.DNSNames,
|
||||
leaf.NotBefore.UTC().Format(time.RFC3339),
|
||||
leaf.NotAfter.UTC().Format(time.RFC3339),
|
||||
now.UTC().Format(time.RFC3339),
|
||||
)
|
||||
mgr.logCertificateDetails(name, leaf, now)
|
||||
} else {
|
||||
mgr.logger.Infof("certificate for domain %q ready in %s", name, elapsed.Round(time.Millisecond))
|
||||
}
|
||||
|
||||
mgr.mu.RLock()
|
||||
info := mgr.domains[d]
|
||||
mgr.mu.RUnlock()
|
||||
|
||||
if info != nil && mgr.certNotifier != nil {
|
||||
if err := mgr.certNotifier.NotifyCertificateIssued(ctx, info.accountID, info.serviceID, name); err != nil {
|
||||
mgr.logger.Warnf("notify certificate ready for domain %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *Manager) setDomainState(d domain.Domain, state domainState, errMsg string) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
if info, ok := mgr.domains[d]; ok {
|
||||
info.state = state
|
||||
info.err = errMsg
|
||||
}
|
||||
}
|
||||
|
||||
// logCertificateDetails logs certificate validity and SCT timestamps.
|
||||
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
|
||||
if cert.NotBefore.After(now) {
|
||||
mgr.logger.Warnf("certificate for %q NotBefore is in the future by %v", domain, cert.NotBefore.Sub(now))
|
||||
}
|
||||
|
||||
sctTimestamps := mgr.parseSCTTimestamps(cert)
|
||||
if len(sctTimestamps) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i, sctTime := range sctTimestamps {
|
||||
if sctTime.After(now) {
|
||||
mgr.logger.Warnf("certificate for %q SCT[%d] timestamp is in the future: %v (by %v)",
|
||||
domain, i, sctTime.UTC(), sctTime.Sub(now))
|
||||
} else {
|
||||
mgr.logger.Debugf("certificate for %q SCT[%d] timestamp: %v (%v in the past)",
|
||||
domain, i, sctTime.UTC(), now.Sub(sctTime))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseSCTTimestamps extracts SCT timestamps from a certificate.
|
||||
func (mgr *Manager) parseSCTTimestamps(cert *x509.Certificate) []time.Time {
|
||||
var timestamps []time.Time
|
||||
|
||||
for _, ext := range cert.Extensions {
|
||||
if !ext.Id.Equal(oidSCTList) {
|
||||
continue
|
||||
}
|
||||
|
||||
// The extension value is an OCTET STRING containing the SCT list
|
||||
var sctListBytes []byte
|
||||
if _, err := asn1.Unmarshal(ext.Value, &sctListBytes); err != nil {
|
||||
mgr.logger.Debugf("failed to unmarshal SCT list outer wrapper: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// SCT list format: 2-byte length prefix, then concatenated SCTs
|
||||
if len(sctListBytes) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
listLen := int(binary.BigEndian.Uint16(sctListBytes[:2]))
|
||||
data := sctListBytes[2:]
|
||||
if len(data) < listLen {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse individual SCTs
|
||||
offset := 0
|
||||
for offset < listLen {
|
||||
if offset+2 > len(data) {
|
||||
break
|
||||
}
|
||||
sctLen := int(binary.BigEndian.Uint16(data[offset : offset+2]))
|
||||
offset += 2
|
||||
|
||||
if offset+sctLen > len(data) {
|
||||
break
|
||||
}
|
||||
sctData := data[offset : offset+sctLen]
|
||||
offset += sctLen
|
||||
|
||||
// SCT format: version (1) + log_id (32) + timestamp (8) + ...
|
||||
if len(sctData) < 41 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Timestamp is at offset 33 (after version + log_id), 8 bytes, milliseconds since epoch
|
||||
tsMillis := binary.BigEndian.Uint64(sctData[33:41])
|
||||
ts := time.UnixMilli(int64(tsMillis))
|
||||
timestamps = append(timestamps, ts)
|
||||
}
|
||||
}
|
||||
|
||||
return timestamps
|
||||
}
|
||||
|
||||
// dummyConn is a do-nothing net.Conn used as the Conn of the synthetic
// ClientHello built for certificate prefetching. It stores a context; none
// of its methods read it. Reads return no data, writes report every byte as
// written, addresses are nil and all other operations succeed trivially.
type dummyConn struct {
	ctx context.Context
}

// Read never yields data and never fails.
func (c *dummyConn) Read(b []byte) (n int, err error) {
	return 0, nil
}

// Write discards b and reports it as fully written.
func (c *dummyConn) Write(b []byte) (n int, err error) {
	return len(b), nil
}

// Close is a no-op.
func (c *dummyConn) Close() error {
	return nil
}

// LocalAddr returns nil: there is no real connection.
func (c *dummyConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil: there is no real connection.
func (c *dummyConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (c *dummyConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (c *dummyConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (c *dummyConn) SetWriteDeadline(t time.Time) error {
	return nil
}
|
||||
|
||||
// RemoveDomain removes a domain from tracking.
|
||||
func (mgr *Manager) RemoveDomain(d domain.Domain) {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
delete(mgr.domains, d)
|
||||
}
|
||||
|
||||
// PendingCerts returns the number of certificates currently being prefetched.
|
||||
func (mgr *Manager) PendingCerts() int {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
var n int
|
||||
for _, info := range mgr.domains {
|
||||
if info.state == domainPending {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// TotalDomains returns the total number of registered domains.
|
||||
func (mgr *Manager) TotalDomains() int {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
return len(mgr.domains)
|
||||
}
|
||||
|
||||
// PendingDomains returns the domain names currently being prefetched.
|
||||
func (mgr *Manager) PendingDomains() []string {
|
||||
return mgr.domainsByState(domainPending)
|
||||
}
|
||||
|
||||
// ReadyDomains returns domain names that have successfully obtained certificates.
|
||||
func (mgr *Manager) ReadyDomains() []string {
|
||||
return mgr.domainsByState(domainReady)
|
||||
}
|
||||
|
||||
// FailedDomains returns domain names that failed certificate prefetch, mapped to their error.
|
||||
func (mgr *Manager) FailedDomains() map[string]string {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
result := make(map[string]string)
|
||||
for d, info := range mgr.domains {
|
||||
if info.state == domainFailed {
|
||||
result[d.PunycodeString()] = info.err
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (mgr *Manager) domainsByState(state domainState) []string {
|
||||
mgr.mu.RLock()
|
||||
defer mgr.mu.RUnlock()
|
||||
var domains []string
|
||||
for d, info := range mgr.domains {
|
||||
if info.state == state {
|
||||
domains = append(domains, d.PunycodeString())
|
||||
}
|
||||
}
|
||||
slices.Sort(domains)
|
||||
return domains
|
||||
}
|
||||
102
proxy/internal/acme/manager_test.go
Normal file
102
proxy/internal/acme/manager_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostPolicy(t *testing.T) {
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
|
||||
mgr.AddDomain("example.com", "acc1", "rp1")
|
||||
|
||||
// Wait for the background prefetch goroutine to finish so the temp dir
|
||||
// can be cleaned up without a race.
|
||||
t.Cleanup(func() {
|
||||
assert.Eventually(t, func() bool {
|
||||
return mgr.PendingCerts() == 0
|
||||
}, 30*time.Second, 50*time.Millisecond)
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "exact domain match",
|
||||
host: "example.com",
|
||||
},
|
||||
{
|
||||
name: "domain with port",
|
||||
host: "example.com:443",
|
||||
},
|
||||
{
|
||||
name: "unknown domain",
|
||||
host: "unknown.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown domain with port",
|
||||
host: "unknown.com:443",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
host: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "port only",
|
||||
host: ":443",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := mgr.hostPolicy(context.Background(), tc.host)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown domain")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDomainStates(t *testing.T) {
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
|
||||
|
||||
assert.Equal(t, 0, mgr.PendingCerts(), "initially zero")
|
||||
assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains")
|
||||
assert.Empty(t, mgr.PendingDomains())
|
||||
assert.Empty(t, mgr.ReadyDomains())
|
||||
assert.Empty(t, mgr.FailedDomains())
|
||||
|
||||
// AddDomain starts as pending, then the prefetch goroutine will fail
|
||||
// (no real ACME server) and transition to failed.
|
||||
mgr.AddDomain("a.example.com", "acc1", "rp1")
|
||||
mgr.AddDomain("b.example.com", "acc1", "rp1")
|
||||
|
||||
assert.Equal(t, 2, mgr.TotalDomains(), "two domains registered")
|
||||
|
||||
// Pending domains should eventually drain after prefetch goroutines finish.
|
||||
assert.Eventually(t, func() bool {
|
||||
return mgr.PendingCerts() == 0
|
||||
}, 30*time.Second, 100*time.Millisecond, "pending certs should return to zero after prefetch completes")
|
||||
|
||||
assert.Empty(t, mgr.PendingDomains())
|
||||
assert.Equal(t, 2, mgr.TotalDomains(), "total domains unchanged")
|
||||
|
||||
// With a fake ACME URL, both should have failed.
|
||||
failed := mgr.FailedDomains()
|
||||
assert.Len(t, failed, 2, "both domains should have failed")
|
||||
assert.Contains(t, failed, "a.example.com")
|
||||
assert.Contains(t, failed, "b.example.com")
|
||||
assert.Empty(t, mgr.ReadyDomains())
|
||||
}
|
||||
18
proxy/internal/auth/auth.gohtml
Normal file
18
proxy/internal/auth/auth.gohtml
Normal file
@@ -0,0 +1,18 @@
|
||||
<!doctype html>
|
||||
{{/*
  Renders one login control per entry in .Methods (method name -> prompt value):
  a PIN form, a password form, or an SSO link. Credential forms are submitted
  with POST so the secret never appears in the URL, and the inputs are masked.
*/}}
{{ range $method, $value := .Methods }}
    {{ if eq $method "pin" }}
    <form method="post">
        <label for="{{ $value }}">PIN:</label>
        <input type="password" name="{{ $value }}" id="{{ $value }}" />
        <button type="submit">Submit</button>
    </form>
    {{ else if eq $method "password" }}
    <form method="post">
        <label for="{{ $value }}">Password:</label>
        <input type="password" name="{{ $value }}" id="{{ $value }}" />
        <button type="submit">Submit</button>
    </form>
    {{ else if eq $method "oidc" }}
    <a href="{{ $value }}">Click here to log in with SSO</a>
    {{ end }}
{{ end }}
|
||||
364
proxy/internal/auth/middleware.go
Normal file
364
proxy/internal/auth/middleware.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type authenticator interface {
|
||||
Authenticate(ctx context.Context, in *proto.AuthenticateRequest, opts ...grpc.CallOption) (*proto.AuthenticateResponse, error)
|
||||
}
|
||||
|
||||
// SessionValidator validates session tokens and checks user access permissions.
|
||||
type SessionValidator interface {
|
||||
ValidateSession(ctx context.Context, in *proto.ValidateSessionRequest, opts ...grpc.CallOption) (*proto.ValidateSessionResponse, error)
|
||||
}
|
||||
|
||||
// Scheme defines an authentication mechanism for a domain.
|
||||
type Scheme interface {
|
||||
Type() auth.Method
|
||||
// Authenticate checks the request and determines whether it represents
|
||||
// an authenticated user. An empty token indicates an unauthenticated
|
||||
// request; optionally, promptData may be returned for the login UI.
|
||||
// An error indicates an infrastructure failure (e.g. gRPC unavailable).
|
||||
Authenticate(*http.Request) (token string, promptData string, err error)
|
||||
}
|
||||
|
||||
type DomainConfig struct {
|
||||
Schemes []Scheme
|
||||
SessionPublicKey ed25519.PublicKey
|
||||
SessionExpiration time.Duration
|
||||
AccountID string
|
||||
ServiceID string
|
||||
}
|
||||
|
||||
// validationResult is the outcome of validating a session token.
type validationResult struct {
	// UserID is the user the token belongs to, as reported by validation.
	UserID string
	// Valid is true when the token was accepted.
	Valid bool
	// DeniedReason carries the validator's explanation when Valid is false.
	DeniedReason string
}
|
||||
|
||||
type Middleware struct {
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]DomainConfig
|
||||
logger *log.Logger
|
||||
sessionValidator SessionValidator
|
||||
}
|
||||
|
||||
// NewMiddleware creates a new authentication middleware.
|
||||
// The sessionValidator is optional; if nil, OIDC session tokens will be validated
|
||||
// locally without group access checks.
|
||||
func NewMiddleware(logger *log.Logger, sessionValidator SessionValidator) *Middleware {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &Middleware{
|
||||
domains: make(map[string]DomainConfig),
|
||||
logger: logger,
|
||||
sessionValidator: sessionValidator,
|
||||
}
|
||||
}
|
||||
|
||||
// Protect applies authentication middleware to the passed handler.
|
||||
// For each incoming request it will be checked against the middleware's
|
||||
// internal list of protected domains.
|
||||
// If the Host domain in the inbound request is not present, then it will
|
||||
// simply be passed through.
|
||||
// However, if the Host domain is present, then the specified authentication
|
||||
// schemes for that domain will be applied to the request.
|
||||
// In the event that no authentication schemes are defined for the domain,
|
||||
// then the request will also be simply passed through.
|
||||
func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
|
||||
config, exists := mw.getDomainConfig(host)
|
||||
mw.logger.Debugf("checking authentication for host: %s, exists: %t", host, exists)
|
||||
|
||||
// Domains that are not configured here or have no authentication schemes applied should simply pass through.
|
||||
if !exists || len(config.Schemes) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Set account and service IDs in captured data for access logging.
|
||||
setCapturedIDs(r, config)
|
||||
|
||||
if mw.handleOAuthCallbackError(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
if mw.forwardWithSessionCookie(w, r, host, config, next) {
|
||||
return
|
||||
}
|
||||
|
||||
mw.authenticateWithSchemes(w, r, host, config)
|
||||
})
|
||||
}
|
||||
|
||||
func (mw *Middleware) getDomainConfig(host string) (DomainConfig, bool) {
|
||||
mw.domainsMux.RLock()
|
||||
defer mw.domainsMux.RUnlock()
|
||||
config, exists := mw.domains[host]
|
||||
return config, exists
|
||||
}
|
||||
|
||||
func setCapturedIDs(r *http.Request, config DomainConfig) {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetAccountId(types.AccountID(config.AccountID))
|
||||
cd.SetServiceId(config.ServiceID)
|
||||
}
|
||||
}
|
||||
|
||||
// handleOAuthCallbackError checks for error query parameters from an OAuth
|
||||
// callback and renders the access denied page if present.
|
||||
func (mw *Middleware) handleOAuthCallbackError(w http.ResponseWriter, r *http.Request) bool {
|
||||
errCode := r.URL.Query().Get("error")
|
||||
if errCode == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(auth.MethodOIDC.String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
errDesc := r.URL.Query().Get("error_description")
|
||||
if errDesc == "" {
|
||||
errDesc = "An error occurred during authentication"
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", errDesc, requestID)
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardWithSessionCookie checks for a valid session cookie and, if found,
|
||||
// sets the user identity on the request context and forwards to the next handler.
|
||||
func (mw *Middleware) forwardWithSessionCookie(w http.ResponseWriter, r *http.Request, host string, config DomainConfig, next http.Handler) bool {
|
||||
cookie, err := r.Cookie(auth.SessionCookieName)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
userID, method, err := auth.ValidateSessionJWT(cookie.Value, host, config.SessionPublicKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetUserID(userID)
|
||||
cd.SetAuthMethod(method)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
return true
|
||||
}
|
||||
|
||||
// authenticateWithSchemes tries each configured auth scheme in order.
|
||||
// On success it sets a session cookie and redirects; on failure it renders the login page.
|
||||
func (mw *Middleware) authenticateWithSchemes(w http.ResponseWriter, r *http.Request, host string, config DomainConfig) {
|
||||
methods := make(map[string]string)
|
||||
var attemptedMethod string
|
||||
|
||||
for _, scheme := range config.Schemes {
|
||||
token, promptData, err := scheme.Authenticate(r)
|
||||
if err != nil {
|
||||
mw.logger.WithField("scheme", scheme.Type().String()).Warnf("authentication infrastructure error: %v", err)
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
}
|
||||
http.Error(w, "authentication service unavailable", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
// Track if credentials were submitted but auth failed
|
||||
if token == "" && wasCredentialSubmitted(r, scheme.Type()) {
|
||||
attemptedMethod = scheme.Type().String()
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
mw.handleAuthenticatedToken(w, r, host, token, config, scheme)
|
||||
return
|
||||
}
|
||||
methods[scheme.Type().String()] = promptData
|
||||
}
|
||||
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
if attemptedMethod != "" {
|
||||
cd.SetAuthMethod(attemptedMethod)
|
||||
}
|
||||
}
|
||||
web.ServeHTTP(w, r, map[string]any{"methods": methods}, http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
// handleAuthenticatedToken validates the token, handles denied access, and on
|
||||
// success sets a session cookie and redirects to the original URL.
|
||||
func (mw *Middleware) handleAuthenticatedToken(w http.ResponseWriter, r *http.Request, host, token string, config DomainConfig, scheme Scheme) {
|
||||
result, err := mw.validateSessionToken(r.Context(), host, token, config.SessionPublicKey, scheme.Type())
|
||||
if err != nil {
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Valid {
|
||||
var requestID string
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
requestID = cd.GetRequestID()
|
||||
}
|
||||
web.ServeAccessDeniedPage(w, r, http.StatusForbidden, "Access Denied", "You are not authorized to access this service", requestID)
|
||||
return
|
||||
}
|
||||
|
||||
expiration := config.SessionExpiration
|
||||
if expiration == 0 {
|
||||
expiration = auth.DefaultSessionExpiry
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: auth.SessionCookieName,
|
||||
Value: token,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: int(expiration.Seconds()),
|
||||
})
|
||||
|
||||
// Redirect instead of forwarding the auth POST to the backend.
|
||||
// The browser will follow with a GET carrying the new session cookie.
|
||||
if cd := proxy.CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(proxy.OriginAuth)
|
||||
cd.SetUserID(result.UserID)
|
||||
cd.SetAuthMethod(scheme.Type().String())
|
||||
}
|
||||
redirectURL := stripSessionTokenParam(r.URL)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// wasCredentialSubmitted checks if credentials were submitted for the given auth method.
|
||||
func wasCredentialSubmitted(r *http.Request, method auth.Method) bool {
|
||||
switch method {
|
||||
case auth.MethodPIN:
|
||||
return r.FormValue("pin") != ""
|
||||
case auth.MethodPassword:
|
||||
return r.FormValue("password") != ""
|
||||
case auth.MethodOIDC:
|
||||
return r.URL.Query().Get("session_token") != ""
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddDomain registers authentication schemes for the given domain.
|
||||
// If schemes are provided, a valid session public key is required to sign/verify
|
||||
// session JWTs. Returns an error if the key is missing or invalid.
|
||||
// Callers must not serve the domain if this returns an error, to avoid
|
||||
// exposing an unauthenticated service.
|
||||
func (mw *Middleware) AddDomain(domain string, schemes []Scheme, publicKeyB64 string, expiration time.Duration, accountID, serviceID string) error {
|
||||
if len(schemes) == 0 {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = DomainConfig{
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
pubKeyBytes, err := base64.StdEncoding.DecodeString(publicKeyB64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode session public key for domain %s: %w", domain, err)
|
||||
}
|
||||
if len(pubKeyBytes) != ed25519.PublicKeySize {
|
||||
return fmt.Errorf("invalid session public key size for domain %s: got %d, want %d", domain, len(pubKeyBytes), ed25519.PublicKeySize)
|
||||
}
|
||||
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
mw.domains[domain] = DomainConfig{
|
||||
Schemes: schemes,
|
||||
SessionPublicKey: pubKeyBytes,
|
||||
SessionExpiration: expiration,
|
||||
AccountID: accountID,
|
||||
ServiceID: serviceID,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mw *Middleware) RemoveDomain(domain string) {
|
||||
mw.domainsMux.Lock()
|
||||
defer mw.domainsMux.Unlock()
|
||||
delete(mw.domains, domain)
|
||||
}
|
||||
|
||||
// validateSessionToken validates a session token, optionally checking group access via gRPC.
|
||||
// For OIDC tokens with a configured validator, it calls ValidateSession to check group access.
|
||||
// For other auth methods (PIN, password), it validates the JWT locally.
|
||||
// Returns a validationResult with user ID and validity status, or error for invalid tokens.
|
||||
func (mw *Middleware) validateSessionToken(ctx context.Context, host, token string, publicKey ed25519.PublicKey, method auth.Method) (*validationResult, error) {
|
||||
// For OIDC with a session validator, call the gRPC service to check group access
|
||||
if method == auth.MethodOIDC && mw.sessionValidator != nil {
|
||||
resp, err := mw.sessionValidator.ValidateSession(ctx, &proto.ValidateSessionRequest{
|
||||
Domain: host,
|
||||
SessionToken: token,
|
||||
})
|
||||
if err != nil {
|
||||
mw.logger.WithError(err).Error("ValidateSession gRPC call failed")
|
||||
return nil, fmt.Errorf("session validation failed")
|
||||
}
|
||||
if !resp.Valid {
|
||||
mw.logger.WithFields(log.Fields{
|
||||
"domain": host,
|
||||
"denied_reason": resp.DeniedReason,
|
||||
"user_id": resp.UserId,
|
||||
}).Debug("Session validation denied")
|
||||
return &validationResult{
|
||||
UserID: resp.UserId,
|
||||
Valid: false,
|
||||
DeniedReason: resp.DeniedReason,
|
||||
}, nil
|
||||
}
|
||||
return &validationResult{UserID: resp.UserId, Valid: true}, nil
|
||||
}
|
||||
|
||||
// For non-OIDC methods or when no validator is configured, validate JWT locally
|
||||
userID, _, err := auth.ValidateSessionJWT(token, host, publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &validationResult{UserID: userID, Valid: true}, nil
|
||||
}
|
||||
|
||||
// stripSessionTokenParam returns u's request URI without the session_token
// query parameter, so the token does not linger in the browser's address bar
// or history. When the parameter is absent the request URI is returned
// unchanged; when it is present the remaining parameters are re-encoded,
// which orders them by key. u itself is never modified.
func stripSessionTokenParam(u *url.URL) string {
	const tokenParam = "session_token"

	params := u.Query()
	if _, present := params[tokenParam]; !present {
		return u.RequestURI()
	}
	delete(params, tokenParam)

	stripped := *u
	stripped.RawQuery = params.Encode()
	return stripped.RequestURI()
}
|
||||
660
proxy/internal/auth/middleware_test.go
Normal file
660
proxy/internal/auth/middleware_test.go
Normal file
@@ -0,0 +1,660 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
)
|
||||
|
||||
func generateTestKeyPair(t *testing.T) *sessionkey.KeyPair {
|
||||
t.Helper()
|
||||
kp, err := sessionkey.GenerateKeyPair()
|
||||
require.NoError(t, err)
|
||||
return kp
|
||||
}
|
||||
|
||||
// stubScheme is a minimal Scheme implementation for testing.
|
||||
type stubScheme struct {
|
||||
method auth.Method
|
||||
token string
|
||||
promptID string
|
||||
authFn func(*http.Request) (string, string, error)
|
||||
}
|
||||
|
||||
func (s *stubScheme) Type() auth.Method { return s.method }
|
||||
|
||||
func (s *stubScheme) Authenticate(r *http.Request) (string, string, error) {
|
||||
if s.authFn != nil {
|
||||
return s.authFn(r)
|
||||
}
|
||||
return s.token, s.promptID, nil
|
||||
}
|
||||
|
||||
// newPassthroughHandler returns a stand-in backend that answers every
// request with 200 and the body "backend".
func newPassthroughHandler() http.Handler {
	respond := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("backend")) // write errors are irrelevant to the tests
	}
	return http.HandlerFunc(respond)
}
|
||||
|
||||
func TestAddDomain_ValidKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
assert.True(t, exists, "domain should be registered")
|
||||
assert.Len(t, config.Schemes, 1)
|
||||
assert.Equal(t, ed25519.PublicKeySize, len(config.SessionPublicKey))
|
||||
assert.Equal(t, time.Hour, config.SessionExpiration)
|
||||
}
|
||||
|
||||
func TestAddDomain_EmptyKey(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "", time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists, "domain must not be registered with an empty session key")
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidBase64(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "not-valid-base64!!!", time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode session public key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists, "domain must not be registered with invalid base64 key")
|
||||
}
|
||||
|
||||
func TestAddDomain_WrongKeySize(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
shortKey := base64.StdEncoding.EncodeToString([]byte("tooshort"))
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, shortKey, time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid session public key size")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists, "domain must not be registered with a wrong-size key")
|
||||
}
|
||||
|
||||
func TestAddDomain_NoSchemes_NoKeyRequired(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
err := mw.AddDomain("example.com", nil, "", time.Hour, "", "")
|
||||
require.NoError(t, err, "domains with no auth schemes should not require a key")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestAddDomain_OverwritesPreviousConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp2.PublicKey, 2*time.Hour, "", ""))
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
config := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
pubKeyBytes, _ := base64.StdEncoding.DecodeString(kp2.PublicKey)
|
||||
assert.Equal(t, ed25519.PublicKey(pubKeyBytes), config.SessionPublicKey, "should use the latest key")
|
||||
assert.Equal(t, 2*time.Hour, config.SessionExpiration)
|
||||
}
|
||||
|
||||
func TestRemoveDomain(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
mw.RemoveDomain("example.com")
|
||||
|
||||
mw.domainsMux.RLock()
|
||||
_, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestProtect_UnknownDomainPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://unknown.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "backend", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProtect_DomainWithNoSchemesPassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
require.NoError(t, mw.AddDomain("example.com", nil, "", time.Hour, "", ""))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "backend", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProtect_UnauthenticatedRequestIsBlocked(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "unauthenticated request should not reach backend")
|
||||
}
|
||||
|
||||
func TestProtect_HostWithPortIsMatched(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com:8443/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "host with port should still match the protected domain")
|
||||
}
|
||||
|
||||
func TestProtect_ValidSessionCookiePassesThrough(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
handler := mw.Protect(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cd := proxy.CapturedDataFromContext(r.Context())
|
||||
require.NotNil(t, cd)
|
||||
assert.Equal(t, "test-user", cd.GetUserID())
|
||||
assert.Equal(t, "pin", cd.GetAuthMethod())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("authenticated"))
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "authenticated", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestProtect_ExpiredSessionCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Sign a token that expired 1 second ago.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "example.com", auth.MethodPIN, -time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "expired session should not reach the backend")
|
||||
}
|
||||
|
||||
func TestProtect_WrongDomainCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Token signed for a different domain audience.
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "test-user", "other.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "cookie for wrong domain should be rejected")
|
||||
}
|
||||
|
||||
func TestProtect_WrongKeyCookieIsRejected(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp1 := generateTestKeyPair(t)
|
||||
kp2 := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp1.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Token signed with a different private key.
|
||||
token, err := sessionkey.SignToken(kp2.PrivateKey, "test-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: token})
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "cookie signed by wrong key should be rejected")
|
||||
}
|
||||
|
||||
func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "pin-user", "example.com", auth.MethodPIN, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(r *http.Request) (string, string, error) {
|
||||
if r.FormValue("pin") == "111111" {
|
||||
return token, "", nil
|
||||
}
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
// Submit the PIN via form POST.
|
||||
form := url.Values{"pin": {"111111"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/somepath", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "backend should not be called during auth, only a redirect should be returned")
|
||||
assert.Equal(t, http.StatusSeeOther, rec.Code)
|
||||
assert.Equal(t, "/somepath", rec.Header().Get("Location"), "redirect should point to the original request URI")
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range cookies {
|
||||
if c.Name == auth.SessionCookieName {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, sessionCookie, "session cookie should be set after successful auth")
|
||||
assert.True(t, sessionCookie.HttpOnly)
|
||||
assert.True(t, sessionCookie.Secure)
|
||||
assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite)
|
||||
}
|
||||
|
||||
func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
assert.NotEqual(t, auth.SessionCookieName, c.Name, "no session cookie should be set on failed auth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtect_MultipleSchemes(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
token, err := sessionkey.SignToken(kp.PrivateKey, "password-user", "example.com", auth.MethodPassword, time.Hour)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First scheme (PIN) always fails, second scheme (password) succeeds.
|
||||
pinScheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
passwordScheme := &stubScheme{
|
||||
method: auth.MethodPassword,
|
||||
authFn: func(r *http.Request) (string, string, error) {
|
||||
if r.FormValue("password") == "secret" {
|
||||
return token, "", nil
|
||||
}
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{pinScheme, passwordScheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
var backendCalled bool
|
||||
backend := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
backendCalled = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := mw.Protect(backend)
|
||||
|
||||
form := url.Values{"password": {"secret"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.False(t, backendCalled, "backend should not be called during auth")
|
||||
assert.Equal(t, http.StatusSeeOther, rec.Code)
|
||||
}
|
||||
|
||||
func TestProtect_InvalidTokenFromSchemeReturns400(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Return a garbage token that won't validate.
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "invalid-jwt-token", "", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
func TestAddDomain_RandomBytes32NotEd25519(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
|
||||
// 32 random bytes that happen to be valid base64 and correct size
|
||||
// but are actually a valid ed25519 public key length-wise.
|
||||
// This should succeed because ed25519 public keys are just 32 bytes.
|
||||
randomBytes := make([]byte, ed25519.PublicKeySize)
|
||||
_, err := rand.Read(randomBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
key := base64.StdEncoding.EncodeToString(randomBytes)
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
|
||||
err = mw.AddDomain("example.com", []Scheme{scheme}, key, time.Hour, "", "")
|
||||
require.NoError(t, err, "any 32-byte key should be accepted at registration time")
|
||||
}
|
||||
|
||||
func TestAddDomain_InvalidKeyDoesNotCorruptExistingConfig(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{method: auth.MethodPIN, promptID: "pin"}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
// Attempt to overwrite with an invalid key.
|
||||
err := mw.AddDomain("example.com", []Scheme{scheme}, "bad", time.Hour, "", "")
|
||||
require.Error(t, err)
|
||||
|
||||
// The original valid config should still be intact.
|
||||
mw.domainsMux.RLock()
|
||||
config, exists := mw.domains["example.com"]
|
||||
mw.domainsMux.RUnlock()
|
||||
|
||||
assert.True(t, exists, "original config should still exist")
|
||||
assert.Len(t, config.Schemes, 1)
|
||||
assert.Equal(t, time.Hour, config.SessionExpiration)
|
||||
}
|
||||
|
||||
func TestProtect_FailedPinAuthCapturesAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
// Scheme that always fails authentication (returns empty token)
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Submit wrong PIN - should capture auth method
|
||||
form := url.Values{"pin": {"wrong-pin"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Equal(t, "pin", capturedData.GetAuthMethod(), "Auth method should be captured for failed PIN auth")
|
||||
}
|
||||
|
||||
func TestProtect_FailedPasswordAuthCapturesAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPassword,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "password", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// Submit wrong password - should capture auth method
|
||||
form := url.Values{"password": {"wrong-password"}}
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Equal(t, "password", capturedData.GetAuthMethod(), "Auth method should be captured for failed password auth")
|
||||
}
|
||||
|
||||
func TestProtect_NoCredentialsDoesNotCaptureAuthMethod(t *testing.T) {
|
||||
mw := NewMiddleware(log.StandardLogger(), nil)
|
||||
kp := generateTestKeyPair(t)
|
||||
|
||||
scheme := &stubScheme{
|
||||
method: auth.MethodPIN,
|
||||
authFn: func(_ *http.Request) (string, string, error) {
|
||||
return "", "pin", nil
|
||||
},
|
||||
}
|
||||
require.NoError(t, mw.AddDomain("example.com", []Scheme{scheme}, kp.PublicKey, time.Hour, "", ""))
|
||||
|
||||
capturedData := &proxy.CapturedData{}
|
||||
handler := mw.Protect(newPassthroughHandler())
|
||||
|
||||
// No credentials submitted - should not capture auth method
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
req = req.WithContext(proxy.WithCapturedData(req.Context(), capturedData))
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
assert.Empty(t, capturedData.GetAuthMethod(), "Auth method should not be captured when no credentials submitted")
|
||||
}
|
||||
|
||||
func TestWasCredentialSubmitted(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method auth.Method
|
||||
formData url.Values
|
||||
query url.Values
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "PIN submitted",
|
||||
method: auth.MethodPIN,
|
||||
formData: url.Values{"pin": {"123456"}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PIN not submitted",
|
||||
method: auth.MethodPIN,
|
||||
formData: url.Values{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Password submitted",
|
||||
method: auth.MethodPassword,
|
||||
formData: url.Values{"password": {"secret"}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Password not submitted",
|
||||
method: auth.MethodPassword,
|
||||
formData: url.Values{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "OIDC token in query",
|
||||
method: auth.MethodOIDC,
|
||||
query: url.Values{"session_token": {"abc123"}},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "OIDC token not in query",
|
||||
method: auth.MethodOIDC,
|
||||
query: url.Values{},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reqURL := "http://example.com/"
|
||||
if len(tt.query) > 0 {
|
||||
reqURL += "?" + tt.query.Encode()
|
||||
}
|
||||
|
||||
var body *strings.Reader
|
||||
if len(tt.formData) > 0 {
|
||||
body = strings.NewReader(tt.formData.Encode())
|
||||
} else {
|
||||
body = strings.NewReader("")
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, reqURL, body)
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
result := wasCredentialSubmitted(req, tt.method)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
65
proxy/internal/auth/oidc.go
Normal file
65
proxy/internal/auth/oidc.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type urlGenerator interface {
|
||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||
}
|
||||
|
||||
type OIDC struct {
|
||||
id string
|
||||
accountId string
|
||||
forwardedProto string
|
||||
client urlGenerator
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(client urlGenerator, id, accountId, forwardedProto string) OIDC {
|
||||
return OIDC{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
forwardedProto: forwardedProto,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (OIDC) Type() auth.Method {
|
||||
return auth.MethodOIDC
|
||||
}
|
||||
|
||||
// Authenticate checks for an OIDC session token or obtains the OIDC redirect URL.
|
||||
func (o OIDC) Authenticate(r *http.Request) (string, string, error) {
|
||||
// Check for the session_token query param (from OIDC redirects).
|
||||
// The management server passes the token in the URL because it cannot set
|
||||
// cookies for the proxy's domain (cookies are domain-scoped per RFC 6265).
|
||||
if token := r.URL.Query().Get("session_token"); token != "" {
|
||||
return token, "", nil
|
||||
}
|
||||
|
||||
redirectURL := &url.URL{
|
||||
Scheme: auth.ResolveProto(o.forwardedProto, r.TLS),
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
}
|
||||
|
||||
res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{
|
||||
Id: o.id,
|
||||
AccountId: o.accountId,
|
||||
RedirectUrl: redirectURL.String(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("get OIDC URL: %w", err)
|
||||
}
|
||||
|
||||
return "", res.GetUrl(), nil
|
||||
}
|
||||
61
proxy/internal/auth/password.go
Normal file
61
proxy/internal/auth/password.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// passwordFormId is the HTTP form field the password is read from; it is
// also returned to callers when a password still needs to be supplied.
const (
	passwordFormId = "password"
)
|
||||
|
||||
type Password struct {
|
||||
id, accountId string
|
||||
client authenticator
|
||||
}
|
||||
|
||||
func NewPassword(client authenticator, id, accountId string) Password {
|
||||
return Password{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (Password) Type() auth.Method {
|
||||
return auth.MethodPassword
|
||||
}
|
||||
|
||||
// Authenticate attempts to authenticate the request using a form
|
||||
// value passed in the request.
|
||||
// If authentication fails, the required HTTP form ID is returned
|
||||
// so that it can be injected into a request from the UI so that
|
||||
// authentication may be successful.
|
||||
func (p Password) Authenticate(r *http.Request) (string, string, error) {
|
||||
password := r.FormValue(passwordFormId)
|
||||
|
||||
if password == "" {
|
||||
// No password submitted; return the form ID so the UI can prompt the user.
|
||||
return "", passwordFormId, nil
|
||||
}
|
||||
|
||||
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: p.id,
|
||||
AccountId: p.accountId,
|
||||
Request: &proto.AuthenticateRequest_Password{
|
||||
Password: &proto.PasswordRequest{
|
||||
Password: password,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("authenticate password: %w", err)
|
||||
}
|
||||
|
||||
if res.GetSuccess() {
|
||||
return res.GetSessionToken(), "", nil
|
||||
}
|
||||
|
||||
return "", passwordFormId, nil
|
||||
}
|
||||
61
proxy/internal/auth/pin.go
Normal file
61
proxy/internal/auth/pin.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// pinFormId is the HTTP form field the PIN is read from; it is also returned
// to callers when a PIN still needs to be supplied.
const (
	pinFormId = "pin"
)
|
||||
|
||||
type Pin struct {
|
||||
id, accountId string
|
||||
client authenticator
|
||||
}
|
||||
|
||||
func NewPin(client authenticator, id, accountId string) Pin {
|
||||
return Pin{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (Pin) Type() auth.Method {
|
||||
return auth.MethodPIN
|
||||
}
|
||||
|
||||
// Authenticate attempts to authenticate the request using a form
|
||||
// value passed in the request.
|
||||
// If authentication fails, the required HTTP form ID is returned
|
||||
// so that it can be injected into a request from the UI so that
|
||||
// authentication may be successful.
|
||||
func (p Pin) Authenticate(r *http.Request) (string, string, error) {
|
||||
pin := r.FormValue(pinFormId)
|
||||
|
||||
if pin == "" {
|
||||
// No PIN submitted; return the form ID so the UI can prompt the user.
|
||||
return "", pinFormId, nil
|
||||
}
|
||||
|
||||
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: p.id,
|
||||
AccountId: p.accountId,
|
||||
Request: &proto.AuthenticateRequest_Pin{
|
||||
Pin: &proto.PinRequest{
|
||||
Pin: pin,
|
||||
},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("authenticate pin: %w", err)
|
||||
}
|
||||
|
||||
if res.GetSuccess() {
|
||||
return res.GetSessionToken(), "", nil
|
||||
}
|
||||
|
||||
return "", pinFormId, nil
|
||||
}
|
||||
279
proxy/internal/certwatch/watcher.go
Normal file
279
proxy/internal/certwatch/watcher.go
Normal file
@@ -0,0 +1,279 @@
|
||||
// Package certwatch watches TLS certificate files on disk and provides
|
||||
// a hot-reloading GetCertificate callback for tls.Config.
|
||||
package certwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// defaultPollInterval is the poll interval a new Watcher starts with.
const defaultPollInterval = 30 * time.Second

// debounceDelay is how long the watcher waits after a relevant fsnotify
// event before reloading, so that cert and key written as separate
// operations are loaded together.
const debounceDelay = 500 * time.Millisecond
|
||||
|
||||
// Watcher monitors TLS certificate files on disk and caches the loaded
|
||||
// certificate in memory. It detects changes via fsnotify (with a polling
|
||||
// fallback for filesystems like NFS that lack inotify support) and
|
||||
// reloads the certificate pair automatically.
|
||||
type Watcher struct {
|
||||
certPath string
|
||||
keyPath string
|
||||
|
||||
mu sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
leaf *x509.Certificate
|
||||
|
||||
pollInterval time.Duration
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewWatcher creates a Watcher that monitors the given cert and key files.
|
||||
// It performs an initial load of the certificate and returns an error
|
||||
// if the initial load fails.
|
||||
func NewWatcher(certPath, keyPath string, logger *log.Logger) (*Watcher, error) {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
w := &Watcher{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
pollInterval: defaultPollInterval,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if err := w.reload(); err != nil {
|
||||
return nil, fmt.Errorf("initial certificate load: %w", err)
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// GetCertificate returns the current in-memory certificate.
|
||||
// It is safe for concurrent use and compatible with tls.Config.GetCertificate.
|
||||
func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
return w.cert, nil
|
||||
}
|
||||
|
||||
// Watch starts watching for certificate file changes. It blocks until
|
||||
// ctx is cancelled. It uses fsnotify for immediate detection and falls
|
||||
// back to polling if fsnotify is unavailable (e.g. on NFS).
|
||||
// Even with fsnotify active, a periodic poll runs as a safety net.
|
||||
func (w *Watcher) Watch(ctx context.Context) {
|
||||
// Watch the parent directory rather than individual files. Some volume
|
||||
// mounts use an atomic symlink swap (..data -> timestamped dir), so
|
||||
// watching the parent directory catches the link replacement.
|
||||
certDir := filepath.Dir(w.certPath)
|
||||
keyDir := filepath.Dir(w.keyPath)
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
w.logger.Warnf("fsnotify unavailable, using polling only: %v", err)
|
||||
w.pollLoop(ctx)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := watcher.Close(); err != nil {
|
||||
w.logger.Debugf("close fsnotify watcher: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := watcher.Add(certDir); err != nil {
|
||||
w.logger.Warnf("fsnotify watch on %s failed, using polling only: %v", certDir, err)
|
||||
w.pollLoop(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
if keyDir != certDir {
|
||||
if err := watcher.Add(keyDir); err != nil {
|
||||
w.logger.Warnf("fsnotify watch on %s failed: %v", keyDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Infof("watching certificate files in %s", certDir)
|
||||
w.fsnotifyLoop(ctx, watcher)
|
||||
}
|
||||
|
||||
func (w *Watcher) fsnotifyLoop(ctx context.Context, watcher *fsnotify.Watcher) {
|
||||
certBase := filepath.Base(w.certPath)
|
||||
keyBase := filepath.Base(w.keyPath)
|
||||
|
||||
var debounce *time.Timer
|
||||
defer func() {
|
||||
if debounce != nil {
|
||||
debounce.Stop()
|
||||
}
|
||||
}()
|
||||
|
||||
// Periodic poll as a safety net for missed fsnotify events.
|
||||
pollTicker := time.NewTicker(w.pollInterval)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
base := filepath.Base(event.Name)
|
||||
if !isRelevantFile(base, certBase, keyBase) {
|
||||
w.logger.Debugf("fsnotify: ignoring event %s on %s", event.Op, event.Name)
|
||||
continue
|
||||
}
|
||||
if !event.Has(fsnotify.Create) && !event.Has(fsnotify.Write) && !event.Has(fsnotify.Rename) {
|
||||
w.logger.Debugf("fsnotify: ignoring op %s on %s", event.Op, base)
|
||||
continue
|
||||
}
|
||||
|
||||
w.logger.Debugf("fsnotify: detected %s on %s, scheduling reload", event.Op, base)
|
||||
|
||||
// Debounce: cert-manager may write cert and key as separate
|
||||
// operations. Wait briefly to load both at once.
|
||||
if debounce != nil {
|
||||
debounce.Stop()
|
||||
}
|
||||
debounce = time.AfterFunc(debounceDelay, func() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
w.tryReload()
|
||||
})
|
||||
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.logger.Warnf("fsnotify error: %v", err)
|
||||
|
||||
case <-pollTicker.C:
|
||||
w.tryReload()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) pollLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.tryReload()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reload loads the certificate from disk and updates the in-memory cache.
|
||||
func (w *Watcher) reload() error {
|
||||
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the leaf for comparison on subsequent reloads.
|
||||
if cert.Leaf == nil && len(cert.Certificate) > 0 {
|
||||
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse leaf certificate: %w", err)
|
||||
}
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.cert = &cert
|
||||
w.leaf = cert.Leaf
|
||||
w.mu.Unlock()
|
||||
|
||||
w.logCertDetails("loaded certificate", cert.Leaf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryReload attempts to reload the certificate. It skips the update
|
||||
// if the certificate on disk is identical to the one in memory (same
|
||||
// serial number and issuer) to avoid redundant log noise.
|
||||
func (w *Watcher) tryReload() {
|
||||
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
|
||||
if err != nil {
|
||||
w.logger.Warnf("reload certificate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if cert.Leaf == nil && len(cert.Certificate) > 0 {
|
||||
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
w.logger.Warnf("parse reloaded leaf certificate: %v", err)
|
||||
return
|
||||
}
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
|
||||
if w.leaf != nil && cert.Leaf != nil &&
|
||||
w.leaf.SerialNumber.Cmp(cert.Leaf.SerialNumber) == 0 &&
|
||||
w.leaf.Issuer.CommonName == cert.Leaf.Issuer.CommonName {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
prev := w.leaf
|
||||
w.cert = &cert
|
||||
w.leaf = cert.Leaf
|
||||
w.mu.Unlock()
|
||||
|
||||
w.logCertChange(prev, cert.Leaf)
|
||||
}
|
||||
|
||||
func (w *Watcher) logCertDetails(msg string, leaf *x509.Certificate) {
|
||||
if leaf == nil {
|
||||
w.logger.Info(msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Infof("%s: subject=%q serial=%s SANs=%v notAfter=%s",
|
||||
msg,
|
||||
leaf.Subject.CommonName,
|
||||
leaf.SerialNumber.Text(16),
|
||||
leaf.DNSNames,
|
||||
leaf.NotAfter.UTC().Format(time.RFC3339),
|
||||
)
|
||||
}
|
||||
|
||||
func (w *Watcher) logCertChange(prev, next *x509.Certificate) {
|
||||
if prev == nil || next == nil {
|
||||
w.logCertDetails("certificate reloaded from disk", next)
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Infof("certificate reloaded from disk: subject=%q -> %q serial=%s -> %s notAfter=%s -> %s",
|
||||
prev.Subject.CommonName, next.Subject.CommonName,
|
||||
prev.SerialNumber.Text(16), next.SerialNumber.Text(16),
|
||||
prev.NotAfter.UTC().Format(time.RFC3339), next.NotAfter.UTC().Format(time.RFC3339),
|
||||
)
|
||||
}
|
||||
|
||||
// isRelevantFile reports whether changed names a file the watcher cares
// about: the certificate file, the key file, or the "..data" symlink used by
// atomic volume mounts.
func isRelevantFile(changed, certBase, keyBase string) bool {
	switch changed {
	case certBase, keyBase, "..data":
		return true
	default:
		return false
	}
}
|
||||
292
proxy/internal/certwatch/watcher_test.go
Normal file
292
proxy/internal/certwatch/watcher_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package certwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert(t *testing.T, serial int64) (certPEM, keyPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(serial),
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
return certPEM, keyPEM
|
||||
}
|
||||
|
||||
func writeCert(t *testing.T, dir string, certPEM, keyPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.crt"), certPEM, 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.key"), keyPEM, 0o600))
|
||||
}
|
||||
|
||||
func TestNewWatcher(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM, keyPEM)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
|
||||
}
|
||||
|
||||
func TestNewWatcherMissingFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
_, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReload(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 100)
|
||||
writeCert(t, dir, certPEM1, keyPEM1)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert1, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(100), cert1.Leaf.SerialNumber.Int64())
|
||||
|
||||
// Write a new cert with a different serial.
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 200)
|
||||
writeCert(t, dir, certPEM2, keyPEM2)
|
||||
|
||||
// Manually trigger reload.
|
||||
w.tryReload()
|
||||
|
||||
cert2, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(200), cert2.Leaf.SerialNumber.Int64())
|
||||
}
|
||||
|
||||
func TestTryReloadSkipsUnchanged(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateSelfSignedCert(t, 42)
|
||||
writeCert(t, dir, certPEM, keyPEM)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert1, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reload with same cert - pointer should remain the same.
|
||||
w.tryReload()
|
||||
|
||||
cert2, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, cert1, cert2, "cert pointer should not change when content is the same")
|
||||
}
|
||||
|
||||
func TestWatchDetectsChanges(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM1, keyPEM1)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use a short poll interval for the test.
|
||||
w.pollInterval = 100 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go w.Watch(ctx)
|
||||
|
||||
// Write new cert.
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 999)
|
||||
writeCert(t, dir, certPEM2, keyPEM2)
|
||||
|
||||
// Wait for the watcher to pick it up.
|
||||
require.Eventually(t, func() bool {
|
||||
cert, err := w.GetCertificate(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cert.Leaf.SerialNumber.Int64() == 999
|
||||
}, 5*time.Second, 50*time.Millisecond, "watcher should detect cert change")
|
||||
}
|
||||
|
||||
func TestIsRelevantFile(t *testing.T) {
|
||||
assert.True(t, isRelevantFile("tls.crt", "tls.crt", "tls.key"))
|
||||
assert.True(t, isRelevantFile("tls.key", "tls.crt", "tls.key"))
|
||||
assert.True(t, isRelevantFile("..data", "tls.crt", "tls.key"))
|
||||
assert.False(t, isRelevantFile("other.txt", "tls.crt", "tls.key"))
|
||||
}
|
||||
|
||||
// TestWatchSymlinkRotation simulates Kubernetes secret volume updates where
|
||||
// the data directory is atomically swapped via a ..data symlink.
|
||||
func TestWatchSymlinkRotation(t *testing.T) {
|
||||
base := t.TempDir()
|
||||
|
||||
// Create initial target directory with certs.
|
||||
dir1 := filepath.Join(base, "dir1")
|
||||
require.NoError(t, os.Mkdir(dir1, 0o755))
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.crt"), certPEM1, 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.key"), keyPEM1, 0o600))
|
||||
|
||||
// Create ..data symlink pointing to dir1.
|
||||
dataLink := filepath.Join(base, "..data")
|
||||
require.NoError(t, os.Symlink(dir1, dataLink))
|
||||
|
||||
// Create tls.crt and tls.key as symlinks to ..data/{file}.
|
||||
certLink := filepath.Join(base, "tls.crt")
|
||||
keyLink := filepath.Join(base, "tls.key")
|
||||
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.crt"), certLink))
|
||||
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.key"), keyLink))
|
||||
|
||||
w, err := NewWatcher(certLink, keyLink, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
|
||||
|
||||
w.pollInterval = 100 * time.Millisecond
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go w.Watch(ctx)
|
||||
|
||||
// Simulate k8s atomic rotation: create dir2, swap ..data symlink.
|
||||
dir2 := filepath.Join(base, "dir2")
|
||||
require.NoError(t, os.Mkdir(dir2, 0o755))
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 777)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.crt"), certPEM2, 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.key"), keyPEM2, 0o600))
|
||||
|
||||
// Atomic swap: create temp link, then rename over ..data.
|
||||
tmpLink := filepath.Join(base, "..data_tmp")
|
||||
require.NoError(t, os.Symlink(dir2, tmpLink))
|
||||
require.NoError(t, os.Rename(tmpLink, dataLink))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
cert, err := w.GetCertificate(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cert.Leaf.SerialNumber.Int64() == 777
|
||||
}, 5*time.Second, 50*time.Millisecond, "watcher should detect symlink rotation")
|
||||
}
|
||||
|
||||
// TestPollLoopDetectsChanges verifies the poll-only fallback path works.
|
||||
func TestPollLoopDetectsChanges(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM1, keyPEM1)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
w.pollInterval = 100 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Directly use pollLoop to test the fallback path.
|
||||
go w.pollLoop(ctx)
|
||||
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 555)
|
||||
writeCert(t, dir, certPEM2, keyPEM2)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
cert, err := w.GetCertificate(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cert.Leaf.SerialNumber.Int64() == 555
|
||||
}, 5*time.Second, 50*time.Millisecond, "poll loop should detect cert change")
|
||||
}
|
||||
|
||||
func TestGetCertificateConcurrency(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM, keyPEM)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Hammer GetCertificate concurrently while reloading.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
w.tryReload()
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
cert, err := w.GetCertificate(&tls.ClientHelloInfo{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cert)
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
||||
388
proxy/internal/debug/client.go
Normal file
388
proxy/internal/debug/client.go
Normal file
@@ -0,0 +1,388 @@
|
||||
// Package debug provides HTTP debug endpoints and a CLI client for the proxy server.
|
||||
package debug
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"time"
)
|
||||
|
||||
// StatusFilters contains filter options for status queries.
// Zero-valued fields are omitted from the request built by ClientStatus.
type StatusFilters struct {
	// IPs is sent comma-joined as the "filter-by-ips" query parameter.
	IPs []string
	// Names is sent comma-joined as the "filter-by-names" query parameter.
	Names []string
	// Status is sent as the "filter-by-status" query parameter.
	Status string
	// ConnectionType is sent as the "filter-by-connection-type" query parameter.
	ConnectionType string
}
|
||||
|
||||
// Client provides CLI access to debug endpoints.
type Client struct {
	// baseURL is the scheme-qualified server address without a trailing slash.
	baseURL string
	// jsonOutput selects indented JSON output instead of formatted text.
	jsonOutput bool
	// httpClient performs the requests; NewClient gives it a 30s timeout.
	httpClient *http.Client
	// out receives everything the client prints.
	out io.Writer
}
|
||||
|
||||
// NewClient creates a new debug client.
|
||||
func NewClient(baseURL string, jsonOutput bool, out io.Writer) *Client {
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
baseURL = "http://" + baseURL
|
||||
}
|
||||
baseURL = strings.TrimSuffix(baseURL, "/")
|
||||
|
||||
return &Client{
|
||||
baseURL: baseURL,
|
||||
jsonOutput: jsonOutput,
|
||||
out: out,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Health fetches the health status from /debug/health and prints it, using
// printHealth for the human-readable form.
func (c *Client) Health(ctx context.Context) error {
	return c.fetchAndPrint(ctx, "/debug/health", c.printHealth)
}
|
||||
|
||||
func (c *Client) printHealth(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Status: %v\n", data["status"])
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
_, _ = fmt.Fprintf(c.out, "Management Connected: %s\n", boolIcon(data["management_connected"]))
|
||||
_, _ = fmt.Fprintf(c.out, "All Clients Healthy: %s\n", boolIcon(data["all_clients_healthy"]))
|
||||
|
||||
total, _ := data["certs_total"].(float64)
|
||||
ready, _ := data["certs_ready"].(float64)
|
||||
pending, _ := data["certs_pending"].(float64)
|
||||
failed, _ := data["certs_failed"].(float64)
|
||||
if total > 0 {
|
||||
_, _ = fmt.Fprintf(c.out, "Certificates: %d ready, %d pending, %d failed (%d total)\n",
|
||||
int(ready), int(pending), int(failed), int(total))
|
||||
}
|
||||
if domains, ok := data["certs_ready_domains"].([]any); ok && len(domains) > 0 {
|
||||
_, _ = fmt.Fprintf(c.out, " Ready:\n")
|
||||
for _, d := range domains {
|
||||
_, _ = fmt.Fprintf(c.out, " %v\n", d)
|
||||
}
|
||||
}
|
||||
if domains, ok := data["certs_pending_domains"].([]any); ok && len(domains) > 0 {
|
||||
_, _ = fmt.Fprintf(c.out, " Pending:\n")
|
||||
for _, d := range domains {
|
||||
_, _ = fmt.Fprintf(c.out, " %v\n", d)
|
||||
}
|
||||
}
|
||||
if domains, ok := data["certs_failed_domains"].(map[string]any); ok && len(domains) > 0 {
|
||||
_, _ = fmt.Fprintf(c.out, " Failed:\n")
|
||||
for d, errMsg := range domains {
|
||||
_, _ = fmt.Fprintf(c.out, " %s: %v\n", d, errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
c.printHealthClients(data)
|
||||
}
|
||||
|
||||
func (c *Client) printHealthClients(data map[string]any) {
|
||||
clients, ok := data["clients"].(map[string]any)
|
||||
if !ok || len(clients) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "\n%-38s %-9s %-7s %-8s %-8s %-16s %s\n",
|
||||
"ACCOUNT ID", "HEALTHY", "MGMT", "SIGNAL", "RELAYS", "PEERS (P2P/RLY)", "DEGRADED")
|
||||
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
|
||||
|
||||
for accountID, v := range clients {
|
||||
ch, ok := v.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
healthy := boolIcon(ch["healthy"])
|
||||
mgmt := boolIcon(ch["management_connected"])
|
||||
signal := boolIcon(ch["signal_connected"])
|
||||
|
||||
relaysConn, _ := ch["relays_connected"].(float64)
|
||||
relaysTotal, _ := ch["relays_total"].(float64)
|
||||
relays := fmt.Sprintf("%d/%d", int(relaysConn), int(relaysTotal))
|
||||
|
||||
peersConnected, _ := ch["peers_connected"].(float64)
|
||||
peersTotal, _ := ch["peers_total"].(float64)
|
||||
peersP2P, _ := ch["peers_p2p"].(float64)
|
||||
peersRelayed, _ := ch["peers_relayed"].(float64)
|
||||
peersDegraded, _ := ch["peers_degraded"].(float64)
|
||||
peers := fmt.Sprintf("%d/%d (%d/%d)", int(peersConnected), int(peersTotal), int(peersP2P), int(peersRelayed))
|
||||
degraded := fmt.Sprintf("%d", int(peersDegraded))
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-9s %-7s %-8s %-8s %-16s %s", accountID, healthy, mgmt, signal, relays, peers, degraded)
|
||||
if errMsg, ok := ch["error"].(string); ok && errMsg != "" {
|
||||
_, _ = fmt.Fprintf(c.out, " (%s)", errMsg)
|
||||
}
|
||||
_, _ = fmt.Fprintln(c.out)
|
||||
}
|
||||
}
|
||||
|
||||
// boolIcon renders a decoded JSON value as "yes" or "no" when it is a bool,
// and as "?" for any other dynamic type (including nil).
func boolIcon(v any) string {
	switch flag, isBool := v.(bool); {
	case !isBool:
		return "?"
	case flag:
		return "yes"
	default:
		return "no"
	}
}
|
||||
|
||||
// ListClients fetches the list of all clients from /debug/clients and prints
// it, using printClients for the human-readable form.
func (c *Client) ListClients(ctx context.Context) error {
	return c.fetchAndPrint(ctx, "/debug/clients", c.printClients)
}
|
||||
|
||||
func (c *Client) printClients(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
|
||||
_, _ = fmt.Fprintf(c.out, "Clients: %v\n\n", data["client_count"])
|
||||
|
||||
clients, ok := data["clients"].([]any)
|
||||
if !ok || len(clients) == 0 {
|
||||
_, _ = fmt.Fprintln(c.out, "No clients connected.")
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12s %-40s %s\n", "ACCOUNT ID", "AGE", "DOMAINS", "HAS CLIENT")
|
||||
_, _ = fmt.Fprintln(c.out, strings.Repeat("-", 110))
|
||||
|
||||
for _, item := range clients {
|
||||
c.printClientRow(item)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) printClientRow(item any) {
|
||||
client, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
domains := c.extractDomains(client)
|
||||
hasClient := "no"
|
||||
if hc, ok := client["has_client"].(bool); ok && hc {
|
||||
hasClient = "yes"
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(c.out, "%-38s %-12v %s %s\n",
|
||||
client["account_id"],
|
||||
client["age"],
|
||||
domains,
|
||||
hasClient,
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Client) extractDomains(client map[string]any) string {
|
||||
d, ok := client["domains"].([]any)
|
||||
if !ok || len(d) == 0 {
|
||||
return "-"
|
||||
}
|
||||
|
||||
parts := make([]string, len(d))
|
||||
for i, domain := range d {
|
||||
parts[i] = fmt.Sprint(domain)
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// ClientStatus fetches the status of a specific client.
|
||||
func (c *Client) ClientStatus(ctx context.Context, accountID string, filters StatusFilters) error {
|
||||
params := url.Values{}
|
||||
if len(filters.IPs) > 0 {
|
||||
params.Set("filter-by-ips", strings.Join(filters.IPs, ","))
|
||||
}
|
||||
if len(filters.Names) > 0 {
|
||||
params.Set("filter-by-names", strings.Join(filters.Names, ","))
|
||||
}
|
||||
if filters.Status != "" {
|
||||
params.Set("filter-by-status", filters.Status)
|
||||
}
|
||||
if filters.ConnectionType != "" {
|
||||
params.Set("filter-by-connection-type", filters.ConnectionType)
|
||||
}
|
||||
|
||||
path := "/debug/clients/" + url.PathEscape(accountID)
|
||||
if len(params) > 0 {
|
||||
path += "?" + params.Encode()
|
||||
}
|
||||
return c.fetchAndPrint(ctx, path, c.printClientStatus)
|
||||
}
|
||||
|
||||
func (c *Client) printClientStatus(data map[string]any) {
|
||||
_, _ = fmt.Fprintf(c.out, "Account: %v\n\n", data["account_id"])
|
||||
if status, ok := data["status"].(string); ok {
|
||||
_, _ = fmt.Fprint(c.out, status)
|
||||
}
|
||||
}
|
||||
|
||||
// ClientSyncResponse fetches the sync response of a specific client.
|
||||
func (c *Client) ClientSyncResponse(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/syncresponse"
|
||||
return c.fetchAndPrintJSON(ctx, path)
|
||||
}
|
||||
|
||||
// PingTCP performs a TCP ping through a client.
|
||||
func (c *Client) PingTCP(ctx context.Context, accountID, host string, port int, timeout string) error {
|
||||
params := url.Values{}
|
||||
params.Set("host", host)
|
||||
params.Set("port", fmt.Sprintf("%d", port))
|
||||
if timeout != "" {
|
||||
params.Set("timeout", timeout)
|
||||
}
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/pingtcp?%s", url.PathEscape(accountID), params.Encode())
|
||||
return c.fetchAndPrint(ctx, path, c.printPingResult)
|
||||
}
|
||||
|
||||
func (c *Client) printPingResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintf(c.out, "Success: %v:%v\n", data["host"], data["port"])
|
||||
_, _ = fmt.Fprintf(c.out, "Latency: %v\n", data["latency"])
|
||||
} else {
|
||||
_, _ = fmt.Fprintf(c.out, "Failed: %v:%v\n", data["host"], data["port"])
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level of a specific client.
|
||||
func (c *Client) SetLogLevel(ctx context.Context, accountID, level string) error {
|
||||
params := url.Values{}
|
||||
params.Set("level", level)
|
||||
|
||||
path := fmt.Sprintf("/debug/clients/%s/loglevel?%s", url.PathEscape(accountID), params.Encode())
|
||||
return c.fetchAndPrint(ctx, path, c.printLogLevelResult)
|
||||
}
|
||||
|
||||
func (c *Client) printLogLevelResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintf(c.out, "Log level set to: %v\n", data["level"])
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed to set log level")
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
// StartClient starts a specific client.
|
||||
func (c *Client) StartClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"
|
||||
return c.fetchAndPrint(ctx, path, c.printStartResult)
|
||||
}
|
||||
|
||||
func (c *Client) printStartResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintln(c.out, "Client started")
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed to start client")
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
// StopClient stops a specific client.
|
||||
func (c *Client) StopClient(ctx context.Context, accountID string) error {
|
||||
path := "/debug/clients/" + url.PathEscape(accountID) + "/stop"
|
||||
return c.fetchAndPrint(ctx, path, c.printStopResult)
|
||||
}
|
||||
|
||||
func (c *Client) printStopResult(data map[string]any) {
|
||||
success, _ := data["success"].(bool)
|
||||
if success {
|
||||
_, _ = fmt.Fprintln(c.out, "Client stopped")
|
||||
} else {
|
||||
_, _ = fmt.Fprintln(c.out, "Failed to stop client")
|
||||
c.printError(data)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) printError(data map[string]any) {
|
||||
if errMsg, ok := data["error"].(string); ok {
|
||||
_, _ = fmt.Fprintf(c.out, "Error: %s\n", errMsg)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrint(ctx context.Context, path string, printer func(map[string]any)) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.jsonOutput {
|
||||
return c.writeJSON(data)
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
printer(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(c.out, string(raw))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) fetchAndPrintJSON(ctx context.Context, path string) error {
|
||||
data, raw, err := c.fetch(ctx, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
return c.writeJSON(data)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(c.out, string(raw))
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeJSON encodes data to c.out as indented JSON and returns the encoder's
// error, if any.
func (c *Client) writeJSON(data map[string]any) error {
	enc := json.NewEncoder(c.out)
	enc.SetIndent("", " ")
	return enc.Encode(data)
}
|
||||
|
||||
func (c *Client) fetch(ctx context.Context, path string) (map[string]any, []byte, error) {
|
||||
fullURL := c.baseURL + path
|
||||
if !strings.Contains(path, "format=json") {
|
||||
if strings.Contains(path, "?") {
|
||||
fullURL += "&format=json"
|
||||
} else {
|
||||
fullURL += "?format=json"
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, nil, fmt.Errorf("server error (%d): %s", resp.StatusCode, strings.TrimSpace(string(body)))
|
||||
}
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
return nil, body, nil
|
||||
}
|
||||
|
||||
return data, body, nil
|
||||
}
|
||||
71
proxy/internal/debug/client_test.go
Normal file
71
proxy/internal/debug/client_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPrintHealth_WithCertsAndClients(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "1h30m",
|
||||
"management_connected": true,
|
||||
"all_clients_healthy": true,
|
||||
"certs_total": float64(3),
|
||||
"certs_ready": float64(2),
|
||||
"certs_pending": float64(1),
|
||||
"certs_failed": float64(0),
|
||||
"certs_ready_domains": []any{"a.example.com", "b.example.com"},
|
||||
"certs_pending_domains": []any{"c.example.com"},
|
||||
"clients": map[string]any{
|
||||
"acc-1": map[string]any{
|
||||
"healthy": true,
|
||||
"management_connected": true,
|
||||
"signal_connected": true,
|
||||
"relays_connected": float64(1),
|
||||
"relays_total": float64(2),
|
||||
"peers_connected": float64(3),
|
||||
"peers_total": float64(5),
|
||||
"peers_p2p": float64(2),
|
||||
"peers_relayed": float64(1),
|
||||
"peers_degraded": float64(0),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 1h30m")
|
||||
assert.Contains(t, out, "yes") // management_connected
|
||||
assert.Contains(t, out, "2 ready, 1 pending, 0 failed (3 total)")
|
||||
assert.Contains(t, out, "a.example.com")
|
||||
assert.Contains(t, out, "c.example.com")
|
||||
assert.Contains(t, out, "acc-1")
|
||||
}
|
||||
|
||||
func TestPrintHealth_Minimal(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
c := NewClient("localhost:8444", false, &buf)
|
||||
|
||||
data := map[string]any{
|
||||
"status": "ok",
|
||||
"uptime": "5m",
|
||||
"management_connected": false,
|
||||
"all_clients_healthy": false,
|
||||
}
|
||||
|
||||
c.printHealth(data)
|
||||
out := buf.String()
|
||||
|
||||
assert.Contains(t, out, "Status: ok")
|
||||
assert.Contains(t, out, "Uptime: 5m")
|
||||
assert.NotContains(t, out, "Certificates")
|
||||
assert.NotContains(t, out, "ACCOUNT ID")
|
||||
}
|
||||
712
proxy/internal/debug/handler.go
Normal file
712
proxy/internal/debug/handler.go
Normal file
@@ -0,0 +1,712 @@
|
||||
// Package debug provides HTTP debug endpoints for the proxy server.
|
||||
package debug
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"maps"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
nbembed "github.com/netbirdio/netbird/client/embed"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
//go:embed templates/*.html
|
||||
var templateFS embed.FS
|
||||
|
||||
const defaultPingTimeout = 10 * time.Second
|
||||
|
||||
// formatDuration formats a duration with 2 decimal places using appropriate units.
|
||||
func formatDuration(d time.Duration) string {
|
||||
switch {
|
||||
case d >= time.Hour:
|
||||
return fmt.Sprintf("%.2fh", d.Hours())
|
||||
case d >= time.Minute:
|
||||
return fmt.Sprintf("%.2fm", d.Minutes())
|
||||
case d >= time.Second:
|
||||
return fmt.Sprintf("%.2fs", d.Seconds())
|
||||
case d >= time.Millisecond:
|
||||
return fmt.Sprintf("%.2fms", float64(d.Microseconds())/1000)
|
||||
case d >= time.Microsecond:
|
||||
return fmt.Sprintf("%.2fµs", float64(d.Nanoseconds())/1000)
|
||||
default:
|
||||
return fmt.Sprintf("%dns", d.Nanoseconds())
|
||||
}
|
||||
}
|
||||
|
||||
func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.AccountID {
|
||||
return slices.Sorted(maps.Keys(m))
|
||||
}
|
||||
|
||||
// clientProvider provides access to NetBird clients.
type clientProvider interface {
	// GetClient returns the embedded client for accountID; the boolean is
	// false when no client is available for that account.
	GetClient(accountID types.AccountID) (*nbembed.Client, bool)
	// ListClientsForDebug returns debug information for every known client,
	// keyed by account ID.
	ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
}
|
||||
|
||||
// healthChecker provides health probe state.
// NOTE(review): the method semantics below are inferred from names and
// signatures only — confirm against the health package.
type healthChecker interface {
	// ReadinessProbe reports the current readiness state.
	ReadinessProbe() bool
	// StartupProbe reports the startup state; it accepts a context.
	StartupProbe(ctx context.Context) bool
	// CheckClientsConnected returns an overall flag together with a
	// per-account health map.
	CheckClientsConnected(ctx context.Context) (bool, map[types.AccountID]health.ClientHealth)
}
|
||||
|
||||
// certStatus exposes certificate state for the debug pages. It is optional:
// handleIndex reports zero counts when none has been set.
type certStatus interface {
	// TotalDomains returns the value reported as the total certificate count.
	TotalDomains() int
	// PendingDomains returns the domains reported as pending.
	PendingDomains() []string
	// ReadyDomains returns the domains reported as ready.
	ReadyDomains() []string
	// FailedDomains returns the failed domains mapped to their error text.
	FailedDomains() map[string]string
}
|
||||
|
||||
// Handler provides HTTP debug endpoints.
type Handler struct {
	// provider supplies the clients shown on the debug pages.
	provider clientProvider
	// health supplies health probe state.
	health healthChecker
	// certStatus is optional and is installed via SetCertStatus.
	certStatus certStatus
	// logger is never nil; NewHandler substitutes the standard logger.
	logger *log.Logger
	// startTime is captured in NewHandler and used to report uptime.
	startTime time.Time
	// templates holds the parsed embedded HTML templates.
	templates *template.Template
	// templateMu guards templates.
	templateMu sync.RWMutex
}
|
||||
|
||||
// NewHandler creates a new debug handler.
|
||||
func NewHandler(provider clientProvider, healthChecker healthChecker, logger *log.Logger) *Handler {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
h := &Handler{
|
||||
provider: provider,
|
||||
health: healthChecker,
|
||||
logger: logger,
|
||||
startTime: time.Now(),
|
||||
}
|
||||
if err := h.loadTemplates(); err != nil {
|
||||
logger.Errorf("failed to load embedded templates: %v", err)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
// SetCertStatus sets the certificate status provider for ACME prefetch observability.
// The field is assigned without taking a lock.
// NOTE(review): presumably meant to be called before the handler starts
// serving requests — confirm against callers.
func (h *Handler) SetCertStatus(cs certStatus) {
	h.certStatus = cs
}
|
||||
|
||||
func (h *Handler) loadTemplates() error {
|
||||
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse embedded templates: %w", err)
|
||||
}
|
||||
|
||||
h.templateMu.Lock()
|
||||
h.templates = tmpl
|
||||
h.templateMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) getTemplates() *template.Template {
|
||||
h.templateMu.RLock()
|
||||
defer h.templateMu.RUnlock()
|
||||
return h.templates
|
||||
}
|
||||
|
||||
// ServeHTTP handles debug requests.
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
wantJSON := r.URL.Query().Get("format") == "json" || strings.HasSuffix(path, "/json")
|
||||
path = strings.TrimSuffix(path, "/json")
|
||||
|
||||
switch path {
|
||||
case "/debug", "/debug/":
|
||||
h.handleIndex(w, r, wantJSON)
|
||||
case "/debug/clients":
|
||||
h.handleListClients(w, r, wantJSON)
|
||||
case "/debug/health":
|
||||
h.handleHealth(w, r, wantJSON)
|
||||
default:
|
||||
if h.handleClientRoutes(w, r, path, wantJSON) {
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientRoutes(w http.ResponseWriter, r *http.Request, path string, wantJSON bool) bool {
|
||||
if !strings.HasPrefix(path, "/debug/clients/") {
|
||||
return false
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(path, "/debug/clients/")
|
||||
parts := strings.SplitN(rest, "/", 2)
|
||||
accountID := types.AccountID(parts[0])
|
||||
|
||||
if len(parts) == 1 {
|
||||
h.handleClientStatus(w, r, accountID, wantJSON)
|
||||
return true
|
||||
}
|
||||
|
||||
switch parts[1] {
|
||||
case "syncresponse":
|
||||
h.handleClientSyncResponse(w, r, accountID, wantJSON)
|
||||
case "tools":
|
||||
h.handleClientTools(w, r, accountID)
|
||||
case "pingtcp":
|
||||
h.handlePingTCP(w, r, accountID)
|
||||
case "loglevel":
|
||||
h.handleLogLevel(w, r, accountID)
|
||||
case "start":
|
||||
h.handleClientStart(w, r, accountID)
|
||||
case "stop":
|
||||
h.handleClientStop(w, r, accountID)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// failedDomain pairs a domain with its error text. handleIndex builds a
// slice of these, sorted by Domain, from the failed-domains map for the
// index template.
type failedDomain struct {
	Domain string
	Error  string
}
|
||||
|
||||
// indexData is the data handed to the "index" template by handleIndex.
type indexData struct {
	Version string
	// Uptime is the time since the handler was created, rounded to seconds.
	Uptime      string
	ClientCount int
	// TotalDomains is the sum of DomainCount over all listed clients.
	TotalDomains int
	// The Certs* fields stay at their zero values when no certificate status
	// provider has been set.
	CertsTotal          int
	CertsReady          int
	CertsPending        int
	CertsFailed         int
	CertsPendingDomains []string
	CertsReadyDomains   []string
	// CertsFailedDomains is sorted by domain name.
	CertsFailedDomains []failedDomain
	// Clients is ordered by account ID.
	Clients []clientData
}
|
||||
|
||||
// clientData is one client row in the index and clients templates.
type clientData struct {
	AccountID string
	// Domains is the client's domain list as text, or "-" when it is empty.
	Domains string
	// Age is the time since the client was created, rounded to seconds.
	Age string
	// Status is "Active" when a client exists and "No client" otherwise.
	Status string
}
|
||||
|
||||
func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
totalDomains := 0
|
||||
for _, info := range clients {
|
||||
totalDomains += info.DomainCount
|
||||
}
|
||||
|
||||
var certsTotal, certsReady, certsPending, certsFailed int
|
||||
var certsPendingDomains, certsReadyDomains []string
|
||||
var certsFailedDomains map[string]string
|
||||
if h.certStatus != nil {
|
||||
certsTotal = h.certStatus.TotalDomains()
|
||||
certsPendingDomains = h.certStatus.PendingDomains()
|
||||
certsReadyDomains = h.certStatus.ReadyDomains()
|
||||
certsFailedDomains = h.certStatus.FailedDomains()
|
||||
certsReady = len(certsReadyDomains)
|
||||
certsPending = len(certsPendingDomains)
|
||||
certsFailed = len(certsFailedDomains)
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"domain_count": info.DomainCount,
|
||||
"domains": info.Domains,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
resp := map[string]interface{}{
|
||||
"version": version.NetbirdVersion(),
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"total_domains": totalDomains,
|
||||
"certs_total": certsTotal,
|
||||
"certs_ready": certsReady,
|
||||
"certs_pending": certsPending,
|
||||
"certs_failed": certsFailed,
|
||||
"clients": clientsJSON,
|
||||
}
|
||||
if len(certsPendingDomains) > 0 {
|
||||
resp["certs_pending_domains"] = certsPendingDomains
|
||||
}
|
||||
if len(certsReadyDomains) > 0 {
|
||||
resp["certs_ready_domains"] = certsReadyDomains
|
||||
}
|
||||
if len(certsFailedDomains) > 0 {
|
||||
resp["certs_failed_domains"] = certsFailedDomains
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
return
|
||||
}
|
||||
|
||||
sortedFailed := make([]failedDomain, 0, len(certsFailedDomains))
|
||||
for d, e := range certsFailedDomains {
|
||||
sortedFailed = append(sortedFailed, failedDomain{Domain: d, Error: e})
|
||||
}
|
||||
slices.SortFunc(sortedFailed, func(a, b failedDomain) int {
|
||||
return cmp.Compare(a.Domain, b.Domain)
|
||||
})
|
||||
|
||||
data := indexData{
|
||||
Version: version.NetbirdVersion(),
|
||||
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||
ClientCount: len(clients),
|
||||
TotalDomains: totalDomains,
|
||||
CertsTotal: certsTotal,
|
||||
CertsReady: certsReady,
|
||||
CertsPending: certsPending,
|
||||
CertsFailed: certsFailed,
|
||||
CertsPendingDomains: certsPendingDomains,
|
||||
CertsReadyDomains: certsReadyDomains,
|
||||
CertsFailedDomains: sortedFailed,
|
||||
Clients: make([]clientData, 0, len(clients)),
|
||||
}
|
||||
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
domains := info.Domains.SafeString()
|
||||
if domains == "" {
|
||||
domains = "-"
|
||||
}
|
||||
status := "No client"
|
||||
if info.HasClient {
|
||||
status = "Active"
|
||||
}
|
||||
data.Clients = append(data.Clients, clientData{
|
||||
AccountID: string(info.AccountID),
|
||||
Domains: domains,
|
||||
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "index", data)
|
||||
}
|
||||
|
||||
// clientsData is the data handed to the "clients" template by
// handleListClients.
type clientsData struct {
	// Uptime is the time since the handler was created, rounded to seconds.
	Uptime string
	// Clients is ordered by account ID.
	Clients []clientData
}
|
||||
|
||||
func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, wantJSON bool) {
|
||||
clients := h.provider.ListClientsForDebug()
|
||||
sortedIDs := sortedAccountIDs(clients)
|
||||
|
||||
if wantJSON {
|
||||
clientsJSON := make([]map[string]interface{}, 0, len(clients))
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
clientsJSON = append(clientsJSON, map[string]interface{}{
|
||||
"account_id": info.AccountID,
|
||||
"domain_count": info.DomainCount,
|
||||
"domains": info.Domains,
|
||||
"has_client": info.HasClient,
|
||||
"created_at": info.CreatedAt,
|
||||
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
})
|
||||
}
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"uptime": time.Since(h.startTime).Round(time.Second).String(),
|
||||
"client_count": len(clients),
|
||||
"clients": clientsJSON,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data := clientsData{
|
||||
Uptime: time.Since(h.startTime).Round(time.Second).String(),
|
||||
Clients: make([]clientData, 0, len(clients)),
|
||||
}
|
||||
|
||||
for _, id := range sortedIDs {
|
||||
info := clients[id]
|
||||
domains := info.Domains.SafeString()
|
||||
if domains == "" {
|
||||
domains = "-"
|
||||
}
|
||||
status := "No client"
|
||||
if info.HasClient {
|
||||
status = "Active"
|
||||
}
|
||||
data.Clients = append(data.Clients, clientData{
|
||||
AccountID: string(info.AccountID),
|
||||
Domains: domains,
|
||||
Age: time.Since(info.CreatedAt).Round(time.Second).String(),
|
||||
Status: status,
|
||||
})
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "clients", data)
|
||||
}
|
||||
|
||||
// clientDetailData is the view model for the "clientDetail" template, which
// is shared by the status and sync-response pages of a single client.
type clientDetailData struct {
	// AccountID identifies the client; the template uses it in the title and links.
	AccountID string
	// ActiveTab names the navigation entry to highlight; the handlers set
	// "status" or "syncresponse".
	ActiveTab string
	// Content is the preformatted text shown inside the page's <pre> element.
	Content string
}
|
||||
|
||||
func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
fullStatus, err := client.Status()
|
||||
if err != nil {
|
||||
http.Error(w, "Error getting status: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse filter parameters
|
||||
query := r.URL.Query()
|
||||
statusFilter := query.Get("filter-by-status")
|
||||
connectionTypeFilter := query.Get("filter-by-connection-type")
|
||||
|
||||
var prefixNamesFilter []string
|
||||
var prefixNamesFilterMap map[string]struct{}
|
||||
if names := query.Get("filter-by-names"); names != "" {
|
||||
prefixNamesFilter = strings.Split(names, ",")
|
||||
prefixNamesFilterMap = make(map[string]struct{})
|
||||
for _, name := range prefixNamesFilter {
|
||||
prefixNamesFilterMap[strings.ToLower(strings.TrimSpace(name))] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var ipsFilterMap map[string]struct{}
|
||||
if ips := query.Get("filter-by-ips"); ips != "" {
|
||||
ipsFilterMap = make(map[string]struct{})
|
||||
for _, ip := range strings.Split(ips, ",") {
|
||||
ipsFilterMap[strings.TrimSpace(ip)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
pbStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||
overview := nbstatus.ConvertToStatusOutputOverview(
|
||||
pbStatus,
|
||||
false,
|
||||
version.NetbirdVersion(),
|
||||
statusFilter,
|
||||
prefixNamesFilter,
|
||||
prefixNamesFilterMap,
|
||||
ipsFilterMap,
|
||||
connectionTypeFilter,
|
||||
"",
|
||||
)
|
||||
|
||||
if wantJSON {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"account_id": accountID,
|
||||
"status": overview.FullDetailSummary(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
data := clientDetailData{
|
||||
AccountID: string(accountID),
|
||||
ActiveTab: "status",
|
||||
Content: overview.FullDetailSummary(),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "clientDetail", data)
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientSyncResponse(w http.ResponseWriter, _ *http.Request, accountID types.AccountID, wantJSON bool) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
syncResp, err := client.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
http.Error(w, "Error getting sync response: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if syncResp == nil {
|
||||
http.Error(w, "No sync response available for client: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
opts := protojson.MarshalOptions{
|
||||
EmitUnpopulated: true,
|
||||
UseProtoNames: true,
|
||||
Indent: " ",
|
||||
AllowPartial: true,
|
||||
}
|
||||
|
||||
jsonBytes, err := opts.Marshal(syncResp)
|
||||
if err != nil {
|
||||
http.Error(w, "Error marshaling sync response: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if wantJSON {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write(jsonBytes)
|
||||
return
|
||||
}
|
||||
|
||||
data := clientDetailData{
|
||||
AccountID: string(accountID),
|
||||
ActiveTab: "syncresponse",
|
||||
Content: string(jsonBytes),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "clientDetail", data)
|
||||
}
|
||||
|
||||
// toolsData is the view model for the "tools" template rendered by
// handleClientTools.
type toolsData struct {
	// AccountID identifies the client the tools page operates on.
	AccountID string
}
|
||||
|
||||
func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, accountID types.AccountID) {
|
||||
_, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
http.Error(w, "Client not found: "+string(accountID), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
data := toolsData{
|
||||
AccountID: string(accountID),
|
||||
}
|
||||
|
||||
h.renderTemplate(w, "tools", data)
|
||||
}
|
||||
|
||||
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
host := r.URL.Query().Get("host")
|
||||
portStr := r.URL.Query().Get("port")
|
||||
if host == "" || portStr == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
|
||||
return
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil || port < 1 || port > 65535 {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
|
||||
return
|
||||
}
|
||||
|
||||
timeout := defaultPingTimeout
|
||||
if t := r.URL.Query().Get("timeout"); t != "" {
|
||||
if d, err := time.ParseDuration(t); err == nil {
|
||||
timeout = d
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), timeout)
|
||||
defer cancel()
|
||||
|
||||
address := fmt.Sprintf("%s:%d", host, port)
|
||||
start := time.Now()
|
||||
|
||||
conn, err := client.Dial(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
h.logger.Debugf("close tcp ping connection: %v", err)
|
||||
}
|
||||
|
||||
latency := time.Since(start)
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"host": host,
|
||||
"port": port,
|
||||
"latency_ms": latency.Milliseconds(),
|
||||
"latency": formatDuration(latency),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
level := r.URL.Query().Get("level")
|
||||
if level == "" {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := client.SetLogLevel(level); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"level": level,
|
||||
})
|
||||
}
|
||||
|
||||
// clientActionTimeout bounds how long handleClientStart and handleClientStop
// wait for the client's Start/Stop call to finish.
const clientActionTimeout = 30 * time.Second
|
||||
|
||||
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "client started",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
|
||||
client, ok := h.provider.GetClient(accountID)
|
||||
if !ok {
|
||||
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), clientActionTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": false,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.writeJSON(w, map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "client stopped",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) {
|
||||
if !wantJSON {
|
||||
http.Redirect(w, r, "/debug", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
uptime := time.Since(h.startTime).Round(10 * time.Millisecond).String()
|
||||
|
||||
ready := h.health.ReadinessProbe()
|
||||
allHealthy, clientHealth := h.health.CheckClientsConnected(r.Context())
|
||||
|
||||
status := "ok"
|
||||
// No clients is not a health issue; only degrade when actual clients are unhealthy
|
||||
if !ready || (!allHealthy && len(clientHealth) > 0) {
|
||||
status = "degraded"
|
||||
}
|
||||
|
||||
var certsTotal, certsReady, certsPending, certsFailed int
|
||||
var certsPendingDomains, certsReadyDomains []string
|
||||
var certsFailedDomains map[string]string
|
||||
if h.certStatus != nil {
|
||||
certsTotal = h.certStatus.TotalDomains()
|
||||
certsPendingDomains = h.certStatus.PendingDomains()
|
||||
certsReadyDomains = h.certStatus.ReadyDomains()
|
||||
certsFailedDomains = h.certStatus.FailedDomains()
|
||||
certsReady = len(certsReadyDomains)
|
||||
certsPending = len(certsPendingDomains)
|
||||
certsFailed = len(certsFailedDomains)
|
||||
}
|
||||
|
||||
resp := map[string]any{
|
||||
"status": status,
|
||||
"uptime": uptime,
|
||||
"management_connected": ready,
|
||||
"all_clients_healthy": allHealthy,
|
||||
"certs_total": certsTotal,
|
||||
"certs_ready": certsReady,
|
||||
"certs_pending": certsPending,
|
||||
"certs_failed": certsFailed,
|
||||
"clients": clientHealth,
|
||||
}
|
||||
if len(certsPendingDomains) > 0 {
|
||||
resp["certs_pending_domains"] = certsPendingDomains
|
||||
}
|
||||
if len(certsReadyDomains) > 0 {
|
||||
resp["certs_ready_domains"] = certsReadyDomains
|
||||
}
|
||||
if len(certsFailedDomains) > 0 {
|
||||
resp["certs_failed_domains"] = certsFailedDomains
|
||||
}
|
||||
h.writeJSON(w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
tmpl := h.getTemplates()
|
||||
if tmpl == nil {
|
||||
http.Error(w, "Templates not loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if err := tmpl.ExecuteTemplate(w, name, data); err != nil {
|
||||
h.logger.Errorf("execute template %s: %v", name, err)
|
||||
http.Error(w, "Template error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(v); err != nil {
|
||||
h.logger.Errorf("encode JSON response: %v", err)
|
||||
}
|
||||
}
|
||||
101
proxy/internal/debug/templates/base.html
Normal file
101
proxy/internal/debug/templates/base.html
Normal file
@@ -0,0 +1,101 @@
|
||||
{{define "style"}}
|
||||
body {
|
||||
font-family: monospace;
|
||||
margin: 20px;
|
||||
background: #1a1a1a;
|
||||
color: #eee;
|
||||
}
|
||||
a {
|
||||
color: #6cf;
|
||||
}
|
||||
h1, h2, h3 {
|
||||
color: #fff;
|
||||
}
|
||||
.info {
|
||||
color: #aaa;
|
||||
}
|
||||
table {
|
||||
border-collapse: collapse;
|
||||
margin: 10px 0;
|
||||
}
|
||||
th, td {
|
||||
border: 1px solid #444;
|
||||
padding: 8px;
|
||||
text-align: left;
|
||||
}
|
||||
th {
|
||||
background: #333;
|
||||
}
|
||||
.nav {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.nav a {
|
||||
margin-right: 15px;
|
||||
padding: 8px 16px;
|
||||
background: #333;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.nav a.active {
|
||||
background: #6cf;
|
||||
color: #000;
|
||||
}
|
||||
pre {
|
||||
background: #222;
|
||||
padding: 15px;
|
||||
border-radius: 4px;
|
||||
overflow-x: auto;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
input, select, textarea {
|
||||
background: #333;
|
||||
color: #eee;
|
||||
border: 1px solid #555;
|
||||
padding: 8px;
|
||||
border-radius: 4px;
|
||||
font-family: monospace;
|
||||
}
|
||||
input:focus, select:focus, textarea:focus {
|
||||
outline: none;
|
||||
border-color: #6cf;
|
||||
}
|
||||
button {
|
||||
background: #6cf;
|
||||
color: #000;
|
||||
border: none;
|
||||
padding: 8px 16px;
|
||||
border-radius: 4px;
|
||||
cursor: pointer;
|
||||
font-family: monospace;
|
||||
}
|
||||
button:hover {
|
||||
background: #5be;
|
||||
}
|
||||
button:disabled {
|
||||
background: #555;
|
||||
color: #888;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.form-group label {
|
||||
display: block;
|
||||
margin-bottom: 5px;
|
||||
color: #aaa;
|
||||
}
|
||||
.form-row {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
align-items: flex-end;
|
||||
}
|
||||
.result {
|
||||
margin-top: 20px;
|
||||
}
|
||||
.success {
|
||||
color: #5f5;
|
||||
}
|
||||
.error {
|
||||
color: #f55;
|
||||
}
|
||||
{{end}}
|
||||
19
proxy/internal/debug/templates/client_detail.html
Normal file
19
proxy/internal/debug/templates/client_detail.html
Normal file
@@ -0,0 +1,19 @@
|
||||
{{define "clientDetail"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>Client {{.AccountID}}</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Client: {{.AccountID}}</h1>
|
||||
<div class="nav">
|
||||
<a href="/debug">← Back</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/tools"{{if eq .ActiveTab "tools"}} class="active"{{end}}>Tools</a>
|
||||
<a href="/debug/clients/{{.AccountID}}"{{if eq .ActiveTab "status"}} class="active"{{end}}>Status</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/syncresponse"{{if eq .ActiveTab "syncresponse"}} class="active"{{end}}>Sync Response</a>
|
||||
</div>
|
||||
<pre>{{.Content}}</pre>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
33
proxy/internal/debug/templates/clients.html
Normal file
33
proxy/internal/debug/templates/clients.html
Normal file
@@ -0,0 +1,33 @@
|
||||
{{define "clients"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>Clients</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>All Clients</h1>
|
||||
<p class="info">Uptime: {{.Uptime}} | <a href="/debug">← Back</a></p>
|
||||
{{if .Clients}}
|
||||
<table>
|
||||
<tr>
|
||||
<th>Account ID</th>
|
||||
<th>Domains</th>
|
||||
<th>Age</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
{{range .Clients}}
|
||||
<tr>
|
||||
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||
<td>{{.Domains}}</td>
|
||||
<td>{{.Age}}</td>
|
||||
<td>{{.Status}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
{{else}}
|
||||
<p>No clients connected</p>
|
||||
{{end}}
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
58
proxy/internal/debug/templates/index.html
Normal file
58
proxy/internal/debug/templates/index.html
Normal file
@@ -0,0 +1,58 @@
|
||||
{{define "index"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>NetBird Proxy Debug</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>NetBird Proxy Debug</h1>
|
||||
<p class="info">Version: {{.Version}} | Uptime: {{.Uptime}}</p>
|
||||
<h2>Certificates: {{.CertsReady}} ready, {{.CertsPending}} pending, {{.CertsFailed}} failed ({{.CertsTotal}} total)</h2>
|
||||
{{if .CertsReadyDomains}}
|
||||
<details>
|
||||
<summary>Ready domains ({{.CertsReady}})</summary>
|
||||
<ul>{{range .CertsReadyDomains}}<li>{{.}}</li>{{end}}</ul>
|
||||
</details>
|
||||
{{end}}
|
||||
{{if .CertsPendingDomains}}
|
||||
<details open>
|
||||
<summary>Pending domains ({{.CertsPending}})</summary>
|
||||
<ul>{{range .CertsPendingDomains}}<li>{{.}}</li>{{end}}</ul>
|
||||
</details>
|
||||
{{end}}
|
||||
{{if .CertsFailedDomains}}
|
||||
<details open>
|
||||
<summary>Failed domains ({{.CertsFailed}})</summary>
|
||||
<ul>{{range .CertsFailedDomains}}<li>{{.Domain}}: {{.Error}}</li>{{end}}</ul>
|
||||
</details>
|
||||
{{end}}
|
||||
<h2>Clients ({{.ClientCount}}) | Domains ({{.TotalDomains}})</h2>
|
||||
{{if .Clients}}
|
||||
<table>
|
||||
<tr>
|
||||
<th>Account ID</th>
|
||||
<th>Domains</th>
|
||||
<th>Age</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
{{range .Clients}}
|
||||
<tr>
|
||||
<td><a href="/debug/clients/{{.AccountID}}/tools">{{.AccountID}}</a></td>
|
||||
<td>{{.Domains}}</td>
|
||||
<td>{{.Age}}</td>
|
||||
<td>{{.Status}}</td>
|
||||
</tr>
|
||||
{{end}}
|
||||
</table>
|
||||
{{else}}
|
||||
<p>No clients connected</p>
|
||||
{{end}}
|
||||
<h2>Endpoints</h2>
|
||||
<ul>
|
||||
<li><a href="/debug/clients">/debug/clients</a> - all clients detail</li>
|
||||
</ul>
|
||||
<p class="info">Add ?format=json or /json suffix for JSON output</p>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
142
proxy/internal/debug/templates/tools.html
Normal file
142
proxy/internal/debug/templates/tools.html
Normal file
@@ -0,0 +1,142 @@
|
||||
{{define "tools"}}
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>Client {{.AccountID}} - Tools</title>
|
||||
<style>{{template "style"}}</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Client: {{.AccountID}}</h1>
|
||||
<div class="nav">
|
||||
<a href="/debug">← Back</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/tools" class="active">Tools</a>
|
||||
<a href="/debug/clients/{{.AccountID}}">Status</a>
|
||||
<a href="/debug/clients/{{.AccountID}}/syncresponse">Sync Response</a>
|
||||
</div>
|
||||
|
||||
<h2>Client Control</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<button onclick="startClient()">Start</button>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<button onclick="stopClient()">Stop</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="client-result" class="result"></div>
|
||||
|
||||
<h2>Log Level</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label for="log-level">Level</label>
|
||||
<select id="log-level" style="width: 120px;">
|
||||
<option value="trace">trace</option>
|
||||
<option value="debug">debug</option>
|
||||
<option value="info">info</option>
|
||||
<option value="warn" selected>warn</option>
|
||||
<option value="error">error</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<button onclick="setLogLevel()">Set Level</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="log-result" class="result"></div>
|
||||
|
||||
<h2>TCP Ping</h2>
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label for="tcp-host">Host</label>
|
||||
<input type="text" id="tcp-host" placeholder="100.0.0.1 or hostname.netbird.cloud" style="width: 300px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="tcp-port">Port</label>
|
||||
<input type="number" id="tcp-port" placeholder="80" style="width: 80px;">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<span> </span>
|
||||
<button onclick="doTcpPing()">Connect</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="tcp-result" class="result"></div>
|
||||
|
||||
<script>
const accountID = "{{.AccountID}}";

// Show a single status line inside target. The message is assigned through
// textContent, so text that originates from the server or from user input
// (error strings, echoed host names) is never parsed as HTML.
function showResult(target, cssClass, message) {
    const line = document.createElement('span');
    line.className = cssClass;
    line.textContent = message;
    target.textContent = '';
    target.appendChild(line);
}

// Call a debug endpoint of this client and report the outcome in target.
// describeSuccess and describeFailure turn the decoded JSON body into the
// text shown after the check mark or cross.
async function callEndpoint(target, pendingMessage, pathAndQuery, describeSuccess, describeFailure) {
    showResult(target, 'info', pendingMessage);
    try {
        const resp = await fetch('/debug/clients/' + accountID + pathAndQuery);
        const data = await resp.json();
        if (data.success) {
            showResult(target, 'success', '✓ ' + describeSuccess(data));
        } else {
            showResult(target, 'error', '✗ ' + describeFailure(data));
        }
    } catch (e) {
        showResult(target, 'error', 'Error: ' + e.message);
    }
}

async function startClient() {
    await callEndpoint(
        document.getElementById('client-result'),
        'Starting client...',
        '/start',
        function (data) { return data.message; },
        function (data) { return data.error; }
    );
}

async function stopClient() {
    await callEndpoint(
        document.getElementById('client-result'),
        'Stopping client...',
        '/stop',
        function (data) { return data.message; },
        function (data) { return data.error; }
    );
}

async function setLogLevel() {
    const level = document.getElementById('log-level').value;
    await callEndpoint(
        document.getElementById('log-result'),
        'Setting log level...',
        '/loglevel?level=' + encodeURIComponent(level),
        function (data) { return 'Log level set to: ' + data.level; },
        function (data) { return data.error; }
    );
}

async function doTcpPing() {
    const host = document.getElementById('tcp-host').value;
    const port = document.getElementById('tcp-port').value;
    if (!host || !port) {
        alert('Host and port required');
        return;
    }
    await callEndpoint(
        document.getElementById('tcp-result'),
        'Connecting...',
        '/pingtcp?host=' + encodeURIComponent(host) + '&port=' + encodeURIComponent(port),
        function (data) { return data.host + ':' + data.port + ' connected in ' + data.latency; },
        function (data) {
            // Validation failures carry only an error, without host/port.
            if (data.host === undefined) {
                return data.error;
            }
            return data.host + ':' + data.port + ': ' + data.error;
        }
    );
}
</script>
|
||||
</body>
|
||||
</html>
|
||||
{{end}}
|
||||
20
proxy/internal/flock/flock_other.go
Normal file
20
proxy/internal/flock/flock_other.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build !unix
|
||||
|
||||
package flock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Lock does nothing on non-Unix platforms, where file locking is not
// supported. It always returns a nil file and a nil error; callers must read
// the nil file as "no lock was taken, carry on unlocked" and not as "the lock
// is held elsewhere".
func Lock(_ context.Context, _ string) (*os.File, error) {
	var noLock *os.File // stays nil: signals that locking is unsupported here
	return noLock, nil
}

// Unlock does nothing on non-Unix platforms and always reports success.
func Unlock(_ *os.File) error {
	return nil
}
|
||||
79
proxy/internal/flock/flock_test.go
Normal file
79
proxy/internal/flock/flock_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
//go:build unix
|
||||
|
||||
package flock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLockUnlock(t *testing.T) {
|
||||
lockPath := filepath.Join(t.TempDir(), "test.lock")
|
||||
|
||||
f, err := Lock(context.Background(), lockPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, f)
|
||||
|
||||
_, err = os.Stat(lockPath)
|
||||
assert.NoError(t, err, "lock file should exist")
|
||||
|
||||
err = Unlock(f)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnlockNil(t *testing.T) {
|
||||
err := Unlock(nil)
|
||||
assert.NoError(t, err, "unlocking nil should be a no-op")
|
||||
}
|
||||
|
||||
func TestLockRespectsContext(t *testing.T) {
|
||||
lockPath := filepath.Join(t.TempDir(), "test.lock")
|
||||
|
||||
f1, err := Lock(context.Background(), lockPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, Unlock(f1)) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = Lock(ctx, lockPath)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// TestLockBlocks checks that a second Lock on a held file does not return
// until the first holder releases it.
func TestLockBlocks(t *testing.T) {
	lockPath := filepath.Join(t.TempDir(), "test.lock")

	f1, err := Lock(context.Background(), lockPath)
	require.NoError(t, err)

	var wg sync.WaitGroup
	wg.Add(1)

	// start is taken before the goroutine launches; elapsed is written by the
	// goroutine and only read after wg.Wait returns.
	start := time.Now()
	var elapsed time.Duration

	go func() {
		defer wg.Done()
		f2, err := Lock(context.Background(), lockPath)
		elapsed = time.Since(start)
		assert.NoError(t, err)
		if f2 != nil {
			assert.NoError(t, Unlock(f2))
		}
	}()

	// Hold the lock for 200ms, then release.
	time.Sleep(200 * time.Millisecond)
	require.NoError(t, Unlock(f1))

	wg.Wait()
	// The 150ms bound leaves slack below the 200ms hold time.
	assert.GreaterOrEqual(t, elapsed, 150*time.Millisecond,
		"Lock should have blocked for at least ~200ms")
}
|
||||
77
proxy/internal/flock/flock_unix.go
Normal file
77
proxy/internal/flock/flock_unix.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build unix
|
||||
|
||||
// Package flock provides best-effort advisory file locking using flock(2).
|
||||
//
|
||||
// This is used for cross-replica coordination (e.g. preventing duplicate
|
||||
// ACME requests). Note that flock(2) does NOT work reliably on NFS volumes:
|
||||
// on NFSv3 it depends on the NLM daemon, on NFSv4 Linux emulates it via
|
||||
// fcntl locks with different semantics. Callers must treat lock failures
|
||||
// as non-fatal and proceed without the lock.
|
||||
package flock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// retryInterval is how long Lock waits between non-blocking flock attempts
// while the lock is held elsewhere.
const retryInterval = 100 * time.Millisecond
|
||||
|
||||
// Lock acquires an exclusive advisory lock on the given file path.
|
||||
// It creates the lock file if it does not exist. The lock attempt
|
||||
// respects context cancellation by using non-blocking flock with polling.
|
||||
// The caller must call Unlock with the returned *os.File when done.
|
||||
func Lock(ctx context.Context, path string) (*os.File, error) {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open lock file %s: %w", path, err)
|
||||
}
|
||||
|
||||
timer := time.NewTimer(retryInterval)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
|
||||
return f, nil
|
||||
} else if !errors.Is(err, syscall.EWOULDBLOCK) {
|
||||
if cerr := f.Close(); cerr != nil {
|
||||
log.Debugf("close lock file %s: %v", path, cerr)
|
||||
}
|
||||
return nil, fmt.Errorf("acquire lock on %s: %w", path, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cerr := f.Close(); cerr != nil {
|
||||
log.Debugf("close lock file %s: %v", path, cerr)
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
timer.Reset(retryInterval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unlock releases the lock and closes the file.
|
||||
func Unlock(f *os.File) error {
|
||||
if f == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if cerr := f.Close(); cerr != nil {
|
||||
log.Debugf("close lock file: %v", cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
|
||||
return fmt.Errorf("release lock: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
48
proxy/internal/grpc/auth.go
Normal file
48
proxy/internal/grpc/auth.go
Normal file
@@ -0,0 +1,48 @@
|
||||
// Package grpc provides gRPC utilities for the proxy client.
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// EnvProxyAllowInsecure controls whether the proxy token can be sent over non-TLS connections.
|
||||
const EnvProxyAllowInsecure = "NB_PROXY_ALLOW_INSECURE"
|
||||
|
||||
// Compile-time check that proxyAuthToken implements credentials.PerRPCCredentials.
var _ credentials.PerRPCCredentials = (*proxyAuthToken)(nil)
|
||||
|
||||
// proxyAuthToken attaches the proxy access token to outgoing RPCs as a
// bearer credential.
type proxyAuthToken struct {
	// token is the raw access token, sent without further encoding.
	token string
	// allowInsecure permits sending the token over connections without
	// transport security.
	allowInsecure bool
}

// GetRequestMetadata returns the authorization header carrying the token.
// The context and URI arguments are not used and it never returns an error.
func (p proxyAuthToken) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) {
	md := make(map[string]string, 1)
	md["authorization"] = "Bearer " + p.token
	return md, nil
}

// RequireTransportSecurity returns true by default to protect the token in transit.
// Set NB_PROXY_ALLOW_INSECURE=true to allow non-TLS connections (not recommended for production).
func (p proxyAuthToken) RequireTransportSecurity() bool {
	return !p.allowInsecure
}
|
||||
|
||||
// WithProxyToken returns a DialOption that sets the proxy access token on each outbound RPC.
|
||||
func WithProxyToken(token string) grpc.DialOption {
|
||||
allowInsecure := false
|
||||
if val := os.Getenv(EnvProxyAllowInsecure); val != "" {
|
||||
parsed, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid value for %s: %v", EnvProxyAllowInsecure, err)
|
||||
} else {
|
||||
allowInsecure = parsed
|
||||
}
|
||||
}
|
||||
return grpc.WithPerRPCCredentials(proxyAuthToken{token: token, allowInsecure: allowInsecure})
|
||||
}
|
||||
405
proxy/internal/health/health.go
Normal file
405
proxy/internal/health/health.go
Normal file
@@ -0,0 +1,405 @@
|
||||
// Package health provides health probes for the proxy server.
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// handshakeStaleThreshold is the maximum age of a connected peer's last
// WireGuard handshake before checkClientHealth counts that peer as degraded.
const handshakeStaleThreshold = 5 * time.Minute

const (
	// maxConcurrentChecks is the capacity of the Checker.checkSem semaphore,
	// i.e. how many client health checks may run at the same time.
	maxConcurrentChecks = 3
	// maxClientCheckTimeout is the upper bound CheckClientsConnected places on
	// the context it is given.
	maxClientCheckTimeout = 5 * time.Minute
)
|
||||
|
||||
// clientProvider provides access to NetBird clients for health checks.
|
||||
type clientProvider interface {
|
||||
ListClientsForStartup() map[types.AccountID]*embed.Client
|
||||
}
|
||||
|
||||
// Checker tracks health state and provides probe endpoints.
type Checker struct {
	// logger receives a debug message when a probe response cannot be written.
	logger *log.Logger
	// provider supplies the clients inspected by CheckClientsConnected.
	provider clientProvider

	// mu guards managementConnected, initialSyncComplete and shuttingDown.
	mu sync.RWMutex
	// managementConnected is updated through SetManagementConnected.
	managementConnected bool
	// initialSyncComplete is set through SetInitialSyncComplete.
	initialSyncComplete bool
	// shuttingDown is set through SetShuttingDown and makes ReadinessProbe
	// return false.
	shuttingDown bool

	// checkSem limits concurrent client health checks.
	checkSem chan struct{}

	// checkHealth checks the health of a single client.
	// Defaults to checkClientHealth; overridable in tests.
	checkHealth func(*embed.Client) ClientHealth
}
|
||||
|
||||
// ClientHealth represents the health status of a single NetBird client.
type ClientHealth struct {
	// Healthy is true when the client is connected to management and signal
	// and, if any relays are defined, to at least one of them.
	Healthy bool `json:"healthy"`
	// ManagementConnected mirrors the client's management connection state.
	ManagementConnected bool `json:"management_connected"`
	// SignalConnected mirrors the client's signal connection state.
	SignalConnected bool `json:"signal_connected"`
	// RelaysConnected counts rel:// and rels:// relays that report no error.
	RelaysConnected int `json:"relays_connected"`
	// RelaysTotal counts all rel:// and rels:// relays.
	RelaysTotal int `json:"relays_total"`
	// PeersTotal is the number of peers in the client status.
	PeersTotal int `json:"peers_total"`
	// PeersConnected counts peers whose status is connected.
	PeersConnected int `json:"peers_connected"`
	// PeersP2P counts connected peers that are not relayed.
	PeersP2P int `json:"peers_p2p"`
	// PeersRelayed counts connected peers that are relayed.
	PeersRelayed int `json:"peers_relayed"`
	// PeersDegraded counts connected peers with no WireGuard handshake yet or
	// with one older than handshakeStaleThreshold.
	PeersDegraded int `json:"peers_degraded"`
	// Error describes why the status could not be obtained; empty otherwise.
	Error string `json:"error,omitempty"`
}
|
||||
|
||||
// ProbeResponse represents the JSON response for health probes.
|
||||
type ProbeResponse struct {
|
||||
Status string `json:"status"`
|
||||
Checks map[string]bool `json:"checks,omitempty"`
|
||||
Clients map[types.AccountID]ClientHealth `json:"clients,omitempty"`
|
||||
}
|
||||
|
||||
// Server runs the health probe HTTP server on a dedicated port.
|
||||
type Server struct {
|
||||
server *http.Server
|
||||
logger *log.Logger
|
||||
checker *Checker
|
||||
}
|
||||
|
||||
// SetManagementConnected updates the management connection state.
|
||||
func (c *Checker) SetManagementConnected(connected bool) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.managementConnected = connected
|
||||
}
|
||||
|
||||
// SetInitialSyncComplete marks that the initial mapping sync has completed.
|
||||
func (c *Checker) SetInitialSyncComplete() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.initialSyncComplete = true
|
||||
}
|
||||
|
||||
// SetShuttingDown marks the server as shutting down.
|
||||
// This causes ReadinessProbe to return false so load balancers stop routing traffic.
|
||||
func (c *Checker) SetShuttingDown() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.shuttingDown = true
|
||||
}
|
||||
|
||||
// CheckClientsConnected verifies all clients are connected to management/signal/relay.
|
||||
// Uses the provided context for timeout/cancellation, with a maximum bound of maxClientCheckTimeout.
|
||||
// Limits concurrent checks via semaphore.
|
||||
func (c *Checker) CheckClientsConnected(ctx context.Context) (bool, map[types.AccountID]ClientHealth) {
|
||||
// Apply upper bound timeout in case parent context has no deadline
|
||||
ctx, cancel := context.WithTimeout(ctx, maxClientCheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
clients := c.provider.ListClientsForStartup()
|
||||
|
||||
// No clients is not a health issue
|
||||
if len(clients) == 0 {
|
||||
return true, make(map[types.AccountID]ClientHealth)
|
||||
}
|
||||
|
||||
type result struct {
|
||||
accountID types.AccountID
|
||||
health ClientHealth
|
||||
}
|
||||
|
||||
resultsCh := make(chan result, len(clients))
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for accountID, client := range clients {
|
||||
wg.Add(1)
|
||||
go func(id types.AccountID, cl *embed.Client) {
|
||||
defer wg.Done()
|
||||
|
||||
// Acquire semaphore
|
||||
select {
|
||||
case c.checkSem <- struct{}{}:
|
||||
defer func() { <-c.checkSem }()
|
||||
case <-ctx.Done():
|
||||
resultsCh <- result{id, ClientHealth{Healthy: false, Error: ctx.Err().Error()}}
|
||||
return
|
||||
}
|
||||
|
||||
resultsCh <- result{id, c.checkHealth(cl)}
|
||||
}(accountID, client)
|
||||
}
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultsCh)
|
||||
}()
|
||||
|
||||
results := make(map[types.AccountID]ClientHealth)
|
||||
allHealthy := true
|
||||
for r := range resultsCh {
|
||||
results[r.accountID] = r.health
|
||||
if !r.health.Healthy {
|
||||
allHealthy = false
|
||||
}
|
||||
}
|
||||
|
||||
return allHealthy, results
|
||||
}
|
||||
|
||||
// LivenessProbe returns true if the process is alive.
|
||||
// This should always return true if we can respond.
|
||||
func (c *Checker) LivenessProbe() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ReadinessProbe returns true if the server can accept traffic.
|
||||
func (c *Checker) ReadinessProbe() bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
if c.shuttingDown {
|
||||
return false
|
||||
}
|
||||
return c.managementConnected
|
||||
}
|
||||
|
||||
// StartupProbe checks if initial startup is complete.
|
||||
// Checks management connection, initial sync, and all client health directly.
|
||||
// Uses the provided context for timeout/cancellation.
|
||||
func (c *Checker) StartupProbe(ctx context.Context) bool {
|
||||
c.mu.RLock()
|
||||
mgmt := c.managementConnected
|
||||
sync := c.initialSyncComplete
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !mgmt || !sync {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check all clients are connected to management/signal/relay.
|
||||
// Returns true when no clients exist (nothing to check).
|
||||
allHealthy, _ := c.CheckClientsConnected(ctx)
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
// Handler returns an http.Handler for health probe endpoints.
|
||||
func (c *Checker) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/healthz/live", c.handleLiveness)
|
||||
mux.HandleFunc("/healthz/ready", c.handleReadiness)
|
||||
mux.HandleFunc("/healthz/startup", c.handleStartup)
|
||||
mux.HandleFunc("/healthz", c.handleFull)
|
||||
return mux
|
||||
}
|
||||
|
||||
func (c *Checker) handleLiveness(w http.ResponseWriter, r *http.Request) {
|
||||
if c.LivenessProbe() {
|
||||
c.writeProbeResponse(w, http.StatusOK, "ok", nil, nil)
|
||||
return
|
||||
}
|
||||
c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", nil, nil)
|
||||
}
|
||||
|
||||
func (c *Checker) handleReadiness(w http.ResponseWriter, r *http.Request) {
|
||||
c.mu.RLock()
|
||||
checks := map[string]bool{
|
||||
"management_connected": c.managementConnected,
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
if c.ReadinessProbe() {
|
||||
c.writeProbeResponse(w, http.StatusOK, "ok", checks, nil)
|
||||
return
|
||||
}
|
||||
c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, nil)
|
||||
}
|
||||
|
||||
func (c *Checker) handleStartup(w http.ResponseWriter, r *http.Request) {
|
||||
c.mu.RLock()
|
||||
mgmt := c.managementConnected
|
||||
syncComplete := c.initialSyncComplete
|
||||
c.mu.RUnlock()
|
||||
|
||||
allClientsHealthy, clientHealth := c.CheckClientsConnected(r.Context())
|
||||
|
||||
checks := map[string]bool{
|
||||
"management_connected": mgmt,
|
||||
"initial_sync_complete": syncComplete,
|
||||
"all_clients_healthy": allClientsHealthy,
|
||||
}
|
||||
|
||||
ready := mgmt && syncComplete && allClientsHealthy
|
||||
if ready {
|
||||
c.writeProbeResponse(w, http.StatusOK, "ok", checks, clientHealth)
|
||||
return
|
||||
}
|
||||
c.writeProbeResponse(w, http.StatusServiceUnavailable, "fail", checks, clientHealth)
|
||||
}
|
||||
|
||||
func (c *Checker) handleFull(w http.ResponseWriter, r *http.Request) {
|
||||
c.mu.RLock()
|
||||
mgmt := c.managementConnected
|
||||
sync := c.initialSyncComplete
|
||||
c.mu.RUnlock()
|
||||
|
||||
allClientsHealthy, clientHealth := c.CheckClientsConnected(r.Context())
|
||||
|
||||
checks := map[string]bool{
|
||||
"management_connected": mgmt,
|
||||
"initial_sync_complete": sync,
|
||||
"all_clients_healthy": allClientsHealthy,
|
||||
}
|
||||
|
||||
status := "ok"
|
||||
statusCode := http.StatusOK
|
||||
if !c.ReadinessProbe() {
|
||||
status = "fail"
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
c.writeProbeResponse(w, statusCode, status, checks, clientHealth)
|
||||
}
|
||||
|
||||
func (c *Checker) writeProbeResponse(w http.ResponseWriter, statusCode int, status string, checks map[string]bool, clients map[types.AccountID]ClientHealth) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
resp := ProbeResponse{
|
||||
Status: status,
|
||||
Checks: checks,
|
||||
Clients: clients,
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
c.logger.Debugf("write health response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndServe starts the health probe server.
|
||||
func (s *Server) ListenAndServe() error {
|
||||
s.logger.Infof("starting health probe server on %s", s.server.Addr)
|
||||
return s.server.ListenAndServe()
|
||||
}
|
||||
|
||||
// Serve starts the health probe server on the given listener.
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
s.logger.Infof("starting health probe server on %s", l.Addr())
|
||||
return s.server.Serve(l)
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the health probe server.
|
||||
func (s *Server) Shutdown(ctx context.Context) error {
|
||||
return s.server.Shutdown(ctx)
|
||||
}
|
||||
|
||||
// NewChecker creates a new health checker.
|
||||
func NewChecker(logger *log.Logger, provider clientProvider) *Checker {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &Checker{
|
||||
logger: logger,
|
||||
provider: provider,
|
||||
checkSem: make(chan struct{}, maxConcurrentChecks),
|
||||
checkHealth: checkClientHealth,
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer creates a new health probe server.
|
||||
// If metricsHandler is non-nil, it is mounted at /metrics on the same port.
|
||||
func NewServer(addr string, checker *Checker, logger *log.Logger, metricsHandler http.Handler) *Server {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
handler := checker.Handler()
|
||||
if metricsHandler != nil {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/metrics", metricsHandler)
|
||||
mux.Handle("/", handler)
|
||||
handler = mux
|
||||
}
|
||||
|
||||
return &Server{
|
||||
server: &http.Server{
|
||||
Addr: addr,
|
||||
Handler: handler,
|
||||
ReadTimeout: 5 * time.Second,
|
||||
WriteTimeout: 5 * time.Second,
|
||||
},
|
||||
logger: logger,
|
||||
checker: checker,
|
||||
}
|
||||
}
|
||||
|
||||
func checkClientHealth(client *embed.Client) ClientHealth {
|
||||
if client == nil {
|
||||
return ClientHealth{
|
||||
Healthy: false,
|
||||
Error: "client not initialized",
|
||||
}
|
||||
}
|
||||
|
||||
status, err := client.Status()
|
||||
if err != nil {
|
||||
return ClientHealth{
|
||||
Healthy: false,
|
||||
Error: err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Count only rel:// and rels:// relays (not stun/turn)
|
||||
var relayCount, relaysConnected int
|
||||
for _, relay := range status.Relays {
|
||||
if !strings.HasPrefix(relay.URI, "rel://") && !strings.HasPrefix(relay.URI, "rels://") {
|
||||
continue
|
||||
}
|
||||
relayCount++
|
||||
if relay.Err == nil {
|
||||
relaysConnected++
|
||||
}
|
||||
}
|
||||
|
||||
// Count peer connection stats
|
||||
now := time.Now()
|
||||
var peersConnected, peersP2P, peersRelayed, peersDegraded int
|
||||
for _, p := range status.Peers {
|
||||
if p.ConnStatus != embed.PeerStatusConnected {
|
||||
continue
|
||||
}
|
||||
peersConnected++
|
||||
if p.Relayed {
|
||||
peersRelayed++
|
||||
} else {
|
||||
peersP2P++
|
||||
}
|
||||
if p.LastWireguardHandshake.IsZero() || now.Sub(p.LastWireguardHandshake) > handshakeStaleThreshold {
|
||||
peersDegraded++
|
||||
}
|
||||
}
|
||||
|
||||
// Client is healthy if connected to management, signal, and at least one relay (if any are defined)
|
||||
healthy := status.ManagementState.Connected &&
|
||||
status.SignalState.Connected &&
|
||||
(relayCount == 0 || relaysConnected > 0)
|
||||
|
||||
return ClientHealth{
|
||||
Healthy: healthy,
|
||||
ManagementConnected: status.ManagementState.Connected,
|
||||
SignalConnected: status.SignalState.Connected,
|
||||
RelaysConnected: relaysConnected,
|
||||
RelaysTotal: relayCount,
|
||||
PeersTotal: len(status.Peers),
|
||||
PeersConnected: peersConnected,
|
||||
PeersP2P: peersP2P,
|
||||
PeersRelayed: peersRelayed,
|
||||
PeersDegraded: peersDegraded,
|
||||
}
|
||||
}
|
||||
473
proxy/internal/health/health_test.go
Normal file
473
proxy/internal/health/health_test.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type mockClientProvider struct {
|
||||
clients map[types.AccountID]*embed.Client
|
||||
}
|
||||
|
||||
func (m *mockClientProvider) ListClientsForStartup() map[types.AccountID]*embed.Client {
|
||||
return m.clients
|
||||
}
|
||||
|
||||
// newTestChecker creates a checker with a mock health function for testing.
|
||||
// The health function returns the provided ClientHealth for every client.
|
||||
func newTestChecker(provider clientProvider, healthResult ClientHealth) *Checker {
|
||||
c := NewChecker(nil, provider)
|
||||
c.checkHealth = func(_ *embed.Client) ClientHealth {
|
||||
return healthResult
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func TestChecker_LivenessProbe(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
|
||||
// Liveness should always return true if we can respond.
|
||||
assert.True(t, checker.LivenessProbe())
|
||||
}
|
||||
|
||||
func TestChecker_ReadinessProbe(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
|
||||
// Initially not ready (management not connected).
|
||||
assert.False(t, checker.ReadinessProbe())
|
||||
|
||||
// After management connects, should be ready.
|
||||
checker.SetManagementConnected(true)
|
||||
assert.True(t, checker.ReadinessProbe())
|
||||
|
||||
// If management disconnects, should not be ready.
|
||||
checker.SetManagementConnected(false)
|
||||
assert.False(t, checker.ReadinessProbe())
|
||||
}
|
||||
|
||||
// TestStartupProbe_EmptyServiceList covers the scenario where management has
|
||||
// no services configured for this proxy. The proxy should become ready once
|
||||
// management is connected and the initial sync completes, even with zero clients.
|
||||
func TestStartupProbe_EmptyServiceList(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
|
||||
// No management connection = not ready.
|
||||
assert.False(t, checker.StartupProbe(context.Background()))
|
||||
|
||||
// Management connected but no sync = not ready.
|
||||
checker.SetManagementConnected(true)
|
||||
assert.False(t, checker.StartupProbe(context.Background()))
|
||||
|
||||
// Management + sync complete + no clients = ready.
|
||||
checker.SetInitialSyncComplete()
|
||||
assert.True(t, checker.StartupProbe(context.Background()))
|
||||
}
|
||||
|
||||
// TestStartupProbe_WithUnhealthyClients verifies that when services exist
|
||||
// and clients have been created but are not yet fully connected (to mgmt,
|
||||
// signal, relays), the startup probe does NOT pass.
|
||||
func TestStartupProbe_WithUnhealthyClients(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil, // concrete client not needed; checkHealth is mocked
|
||||
"account-2": nil,
|
||||
},
|
||||
}
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: false, Error: "not connected yet"})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
|
||||
assert.False(t, checker.StartupProbe(context.Background()),
|
||||
"startup probe must not pass when clients are unhealthy")
|
||||
}
|
||||
|
||||
// TestStartupProbe_WithHealthyClients verifies that once all clients are
|
||||
// connected and healthy, the startup probe passes.
|
||||
func TestStartupProbe_WithHealthyClients(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil,
|
||||
"account-2": nil,
|
||||
},
|
||||
}
|
||||
checker := newTestChecker(provider, ClientHealth{
|
||||
Healthy: true,
|
||||
ManagementConnected: true,
|
||||
SignalConnected: true,
|
||||
RelaysConnected: 1,
|
||||
RelaysTotal: 1,
|
||||
})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
|
||||
assert.True(t, checker.StartupProbe(context.Background()),
|
||||
"startup probe must pass when all clients are healthy")
|
||||
}
|
||||
|
||||
// TestStartupProbe_MixedHealthClients verifies that if any single client is
|
||||
// unhealthy, the startup probe fails (all-or-nothing).
|
||||
func TestStartupProbe_MixedHealthClients(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"healthy-account": nil,
|
||||
"unhealthy-account": nil,
|
||||
},
|
||||
}
|
||||
|
||||
checker := NewChecker(nil, provider)
|
||||
checker.checkHealth = func(cl *embed.Client) ClientHealth {
|
||||
// We identify accounts by their position in the map iteration; since we
|
||||
// can't control map order, make exactly one unhealthy via counter.
|
||||
return ClientHealth{Healthy: false}
|
||||
}
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
|
||||
assert.False(t, checker.StartupProbe(context.Background()),
|
||||
"startup probe must fail if any client is unhealthy")
|
||||
}
|
||||
|
||||
// TestStartupProbe_RequiresAllConditions ensures that each individual
|
||||
// prerequisite (management, sync, clients) is necessary. The probe must not
|
||||
// pass if any one is missing.
|
||||
func TestStartupProbe_RequiresAllConditions(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil,
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("no management", func(t *testing.T) {
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: true})
|
||||
checker.SetInitialSyncComplete()
|
||||
// management NOT connected
|
||||
assert.False(t, checker.StartupProbe(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("no sync", func(t *testing.T) {
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: true})
|
||||
checker.SetManagementConnected(true)
|
||||
// sync NOT complete
|
||||
assert.False(t, checker.StartupProbe(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("unhealthy client", func(t *testing.T) {
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: false})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
assert.False(t, checker.StartupProbe(context.Background()))
|
||||
})
|
||||
|
||||
t.Run("all conditions met", func(t *testing.T) {
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: true})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
assert.True(t, checker.StartupProbe(context.Background()))
|
||||
})
|
||||
}
|
||||
|
||||
// TestStartupProbe_ConcurrentAccess runs the startup probe from many
|
||||
// goroutines simultaneously to check for races.
|
||||
func TestStartupProbe_ConcurrentAccess(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil,
|
||||
"account-2": nil,
|
||||
},
|
||||
}
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: true})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
const goroutines = 50
|
||||
results := make([]bool, goroutines)
|
||||
|
||||
for i := range goroutines {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
results[idx] = checker.StartupProbe(context.Background())
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, r := range results {
|
||||
assert.True(t, r, "goroutine %d got unexpected result", i)
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartupProbe_CancelledContext verifies that a cancelled context causes
|
||||
// the probe to report unhealthy when client checks are needed.
|
||||
func TestStartupProbe_CancelledContext(t *testing.T) {
|
||||
t.Run("no management bypasses context", func(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
// Should be false because management isn't connected, context is irrelevant.
|
||||
assert.False(t, checker.StartupProbe(ctx))
|
||||
})
|
||||
|
||||
t.Run("with clients and cancelled context", func(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil,
|
||||
},
|
||||
}
|
||||
checker := NewChecker(nil, provider)
|
||||
// Use the real checkHealth path — a cancelled context should cause
|
||||
// the semaphore acquisition to fail, reporting unhealthy.
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
assert.False(t, checker.StartupProbe(ctx),
|
||||
"cancelled context must result in unhealthy when clients exist")
|
||||
})
|
||||
}
|
||||
|
||||
// TestHandler_Startup_EmptyServiceList verifies the HTTP startup endpoint
|
||||
// returns 200 when management is connected, sync is complete, and there are
|
||||
// no services/clients.
|
||||
func TestHandler_Startup_EmptyServiceList(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
assert.True(t, resp.Checks["management_connected"])
|
||||
assert.True(t, resp.Checks["initial_sync_complete"])
|
||||
assert.True(t, resp.Checks["all_clients_healthy"])
|
||||
assert.Empty(t, resp.Clients)
|
||||
}
|
||||
|
||||
// TestHandler_Startup_WithUnhealthyClients verifies that the HTTP startup
|
||||
// endpoint returns 503 when clients exist but are not yet healthy.
|
||||
func TestHandler_Startup_WithUnhealthyClients(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil,
|
||||
},
|
||||
}
|
||||
checker := newTestChecker(provider, ClientHealth{Healthy: false, Error: "starting"})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
assert.True(t, resp.Checks["management_connected"])
|
||||
assert.True(t, resp.Checks["initial_sync_complete"])
|
||||
assert.False(t, resp.Checks["all_clients_healthy"])
|
||||
require.Contains(t, resp.Clients, types.AccountID("account-1"))
|
||||
assert.Equal(t, "starting", resp.Clients["account-1"].Error)
|
||||
}
|
||||
|
||||
// TestHandler_Startup_WithHealthyClients verifies the HTTP startup endpoint
|
||||
// returns 200 once clients are healthy.
|
||||
func TestHandler_Startup_WithHealthyClients(t *testing.T) {
|
||||
provider := &mockClientProvider{
|
||||
clients: map[types.AccountID]*embed.Client{
|
||||
"account-1": nil,
|
||||
},
|
||||
}
|
||||
checker := newTestChecker(provider, ClientHealth{
|
||||
Healthy: true,
|
||||
ManagementConnected: true,
|
||||
SignalConnected: true,
|
||||
RelaysConnected: 1,
|
||||
RelaysTotal: 1,
|
||||
})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetInitialSyncComplete()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
assert.True(t, resp.Checks["all_clients_healthy"])
|
||||
}
|
||||
|
||||
// TestHandler_Startup_NotComplete verifies the startup handler returns 503
|
||||
// when prerequisites aren't met.
|
||||
func TestHandler_Startup_NotComplete(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/startup", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Liveness(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Readiness_NotReady(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
assert.False(t, resp.Checks["management_connected"])
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Readiness_Ready(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
assert.True(t, resp.Checks["management_connected"])
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Full(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
assert.NotNil(t, resp.Checks)
|
||||
// Clients may be empty map when no clients exist.
|
||||
assert.Empty(t, resp.Clients)
|
||||
}
|
||||
|
||||
func TestChecker_SetShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
assert.True(t, checker.ReadinessProbe(), "should be ready before shutdown")
|
||||
|
||||
checker.SetShuttingDown()
|
||||
|
||||
assert.False(t, checker.ReadinessProbe(), "should not be ready after shutdown")
|
||||
}
|
||||
|
||||
func TestChecker_Handler_Readiness_ShuttingDown(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
checker.SetShuttingDown()
|
||||
handler := checker.Handler()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/ready", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, rec.Code)
|
||||
|
||||
var resp ProbeResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&resp))
|
||||
assert.Equal(t, "fail", resp.Status)
|
||||
}
|
||||
|
||||
func TestNewServer_WithMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
metricsHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("metrics"))
|
||||
})
|
||||
|
||||
srv := NewServer(":0", checker, nil, metricsHandler)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
// Verify health endpoint still works through the mux.
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
// Verify metrics endpoint is mounted.
|
||||
req = httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
||||
rec = httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "metrics", rec.Body.String())
|
||||
}
|
||||
|
||||
func TestNewServer_WithoutMetricsHandler(t *testing.T) {
|
||||
checker := NewChecker(nil, &mockClientProvider{})
|
||||
checker.SetManagementConnected(true)
|
||||
|
||||
srv := NewServer(":0", checker, nil, nil)
|
||||
require.NotNil(t, srv)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/healthz/live", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
srv.server.Handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
281
proxy/internal/k8s/lease.go
Normal file
281
proxy/internal/k8s/lease.go
Normal file
@@ -0,0 +1,281 @@
|
||||
// Package k8s provides a lightweight Kubernetes API client for coordination
|
||||
// Leases. It uses raw HTTP calls against the mounted service account
|
||||
// credentials, avoiding a dependency on client-go.
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Locations of the service account files that NewLeaseClient and readToken
// read from the pod's filesystem.
const (
	saTokenPath     = "/var/run/secrets/kubernetes.io/serviceaccount/token" //nolint:gosec
	saNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
	saCACertPath    = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
)

// leaseAPIPath is the URL path prefix of the coordination.k8s.io/v1 API
// group; it is appended to the API server base URL when building Lease URLs.
const leaseAPIPath = "/apis/coordination.k8s.io/v1"

// ErrConflict is returned when a Lease create or update is rejected with
// HTTP 409, e.g. because of a resourceVersion mismatch (another writer
// updated the object first).
var ErrConflict = errors.New("conflict: resource version mismatch")
|
||||
|
||||
// Lease represents a coordination.k8s.io/v1 Lease object with only the
|
||||
// fields needed for distributed locking.
|
||||
type Lease struct {
|
||||
APIVersion string `json:"apiVersion"`
|
||||
Kind string `json:"kind"`
|
||||
Metadata LeaseMetadata `json:"metadata"`
|
||||
Spec LeaseSpec `json:"spec"`
|
||||
}
|
||||
|
||||
// LeaseMetadata holds the standard k8s object metadata fields used by Leases.
|
||||
type LeaseMetadata struct {
|
||||
Name string `json:"name"`
|
||||
Namespace string `json:"namespace,omitempty"`
|
||||
ResourceVersion string `json:"resourceVersion,omitempty"`
|
||||
Annotations map[string]string `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// LeaseSpec holds the Lease specification fields.
|
||||
type LeaseSpec struct {
|
||||
HolderIdentity *string `json:"holderIdentity"`
|
||||
LeaseDurationSeconds *int32 `json:"leaseDurationSeconds,omitempty"`
|
||||
AcquireTime *MicroTime `json:"acquireTime"`
|
||||
RenewTime *MicroTime `json:"renewTime"`
|
||||
}
|
||||
|
||||
// MicroTime wraps time.Time with Kubernetes MicroTime JSON formatting.
type MicroTime struct {
	time.Time
}

// microTimeFormat is the layout used when encoding and decoding a MicroTime:
// six fractional-second digits followed by a literal "Z".
const microTimeFormat = "2006-01-02T15:04:05.000000Z"

// MarshalJSON implements json.Marshaler with k8s MicroTime format. The time
// is converted to UTC before formatting.
//
// The receiver is a value (not a pointer) so that both MicroTime and
// *MicroTime satisfy json.Marshaler. With a pointer receiver, a MicroTime
// stored or passed by value would not be encoded with this layout.
func (t MicroTime) MarshalJSON() ([]byte, error) {
	return json.Marshal(t.UTC().Format(microTimeFormat))
}

// UnmarshalJSON implements json.Unmarshaler with k8s MicroTime format.
//
// An empty string decodes to the zero time. Any other value must match
// microTimeFormat exactly; otherwise an error wrapping the parse failure is
// returned and t is left unchanged.
func (t *MicroTime) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	if s == "" {
		t.Time = time.Time{}
		return nil
	}

	parsed, err := time.Parse(microTimeFormat, s)
	if err != nil {
		return fmt.Errorf("parse MicroTime %q: %w", s, err)
	}
	t.Time = parsed
	return nil
}
|
||||
|
||||
// LeaseClient talks to the Kubernetes coordination API using raw HTTP.
|
||||
type LeaseClient struct {
|
||||
baseURL string
|
||||
namespace string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewLeaseClient creates a client that authenticates via the pod's
|
||||
// mounted service account. It reads the namespace and CA certificate
|
||||
// at construction time (they don't rotate) but reads the bearer token
|
||||
// fresh on each request (tokens rotate).
|
||||
func NewLeaseClient() (*LeaseClient, error) {
|
||||
host := os.Getenv("KUBERNETES_SERVICE_HOST")
|
||||
port := os.Getenv("KUBERNETES_SERVICE_PORT")
|
||||
if host == "" || port == "" {
|
||||
return nil, fmt.Errorf("KUBERNETES_SERVICE_HOST/PORT not set")
|
||||
}
|
||||
|
||||
ns, err := os.ReadFile(saNamespacePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read namespace from %s: %w", saNamespacePath, err)
|
||||
}
|
||||
|
||||
caCert, err := os.ReadFile(saCACertPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read CA cert from %s: %w", saCACertPath, err)
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("parse CA certificate from %s", saCACertPath)
|
||||
}
|
||||
|
||||
return &LeaseClient{
|
||||
baseURL: fmt.Sprintf("https://%s:%s", host, port),
|
||||
namespace: strings.TrimSpace(string(ns)),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: pool,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Namespace returns the namespace this client operates in.
|
||||
func (c *LeaseClient) Namespace() string {
|
||||
return c.namespace
|
||||
}
|
||||
|
||||
// Get retrieves a Lease by name. Returns (nil, nil) if the Lease does not exist.
|
||||
func (c *LeaseClient) Get(ctx context.Context, name string) (*Lease, error) {
|
||||
url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, name)
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, c.readError(resp)
|
||||
}
|
||||
|
||||
var lease Lease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&lease); err != nil {
|
||||
return nil, fmt.Errorf("decode lease response: %w", err)
|
||||
}
|
||||
return &lease, nil
|
||||
}
|
||||
|
||||
// Create creates a new Lease. Returns the created Lease with server-assigned
|
||||
// fields like resourceVersion populated.
|
||||
func (c *LeaseClient) Create(ctx context.Context, lease *Lease) (*Lease, error) {
|
||||
url := fmt.Sprintf("%s%s/namespaces/%s/leases", c.baseURL, leaseAPIPath, c.namespace)
|
||||
|
||||
lease.APIVersion = "coordination.k8s.io/v1"
|
||||
lease.Kind = "Lease"
|
||||
if lease.Metadata.Namespace == "" {
|
||||
lease.Metadata.Namespace = c.namespace
|
||||
}
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, url, lease)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode == http.StatusConflict {
|
||||
return nil, ErrConflict
|
||||
}
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return nil, c.readError(resp)
|
||||
}
|
||||
|
||||
var created Lease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
|
||||
return nil, fmt.Errorf("decode created lease: %w", err)
|
||||
}
|
||||
return &created, nil
|
||||
}
|
||||
|
||||
// Update replaces a Lease. The lease.Metadata.ResourceVersion must match
|
||||
// the current server value (optimistic concurrency). Returns ErrConflict
|
||||
// on version mismatch.
|
||||
func (c *LeaseClient) Update(ctx context.Context, lease *Lease) (*Lease, error) {
|
||||
url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, lease.Metadata.Name)
|
||||
|
||||
lease.APIVersion = "coordination.k8s.io/v1"
|
||||
lease.Kind = "Lease"
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPut, url, lease)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode == http.StatusConflict {
|
||||
return nil, ErrConflict
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, c.readError(resp)
|
||||
}
|
||||
|
||||
var updated Lease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
|
||||
return nil, fmt.Errorf("decode updated lease: %w", err)
|
||||
}
|
||||
return &updated, nil
|
||||
}
|
||||
|
||||
func (c *LeaseClient) doRequest(ctx context.Context, method, url string, body any) (*http.Response, error) {
|
||||
token, err := readToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read service account token: %w", err)
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
func readToken() (string, error) {
|
||||
data, err := os.ReadFile(saTokenPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read %s: %w", saTokenPath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
func (c *LeaseClient) readError(resp *http.Response) error {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("k8s API %s %d: %s", resp.Request.URL.Path, resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// LeaseNameForDomain maps a domain to a deterministic Lease name of the form
// "cert-lock-" followed by 16 lowercase hex characters. The suffix is the
// first 8 bytes of the domain's SHA-256 digest, which keeps dots out of the
// name and bounds its length regardless of the domain.
func LeaseNameForDomain(domain string) string {
	digest := sha256.Sum256([]byte(domain))
	suffix := hex.EncodeToString(digest[:8])
	return "cert-lock-" + suffix
}
|
||||
|
||||
// InCluster reports whether the process is running inside a Kubernetes pod
// by checking that the KUBERNETES_SERVICE_HOST environment variable is set
// to a non-empty value.
//
// An empty value is treated as "not in cluster" so that this check agrees
// with NewLeaseClient, which rejects an empty host.
func InCluster() bool {
	return os.Getenv("KUBERNETES_SERVICE_HOST") != ""
}
|
||||
102
proxy/internal/k8s/lease_test.go
Normal file
102
proxy/internal/k8s/lease_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLeaseNameForDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
}{
|
||||
{"example.com"},
|
||||
{"app.example.com"},
|
||||
{"another.domain.io"},
|
||||
}
|
||||
|
||||
seen := make(map[string]string)
|
||||
for _, tc := range tests {
|
||||
name := LeaseNameForDomain(tc.domain)
|
||||
|
||||
assert.True(t, len(name) <= 63, "must be valid DNS label length")
|
||||
assert.Regexp(t, `^cert-lock-[0-9a-f]{16}$`, name,
|
||||
"must match expected format for domain %q", tc.domain)
|
||||
|
||||
// Same input produces same output.
|
||||
assert.Equal(t, name, LeaseNameForDomain(tc.domain), "must be deterministic")
|
||||
|
||||
// Different domains produce different names.
|
||||
if prev, ok := seen[name]; ok {
|
||||
t.Errorf("collision: %q and %q both map to %s", prev, tc.domain, name)
|
||||
}
|
||||
seen[name] = tc.domain
|
||||
}
|
||||
}
|
||||
|
||||
func TestMicroTimeJSON(t *testing.T) {
|
||||
ts := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
mt := &MicroTime{Time: ts}
|
||||
|
||||
data, err := json.Marshal(mt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `"2024-06-15T10:30:00.000000Z"`, string(data))
|
||||
|
||||
var decoded MicroTime
|
||||
require.NoError(t, json.Unmarshal(data, &decoded))
|
||||
assert.True(t, ts.Equal(decoded.Time), "round-trip should preserve time")
|
||||
}
|
||||
|
||||
func TestMicroTimeNullJSON(t *testing.T) {
|
||||
// Null pointer serializes as JSON null via the Lease struct.
|
||||
spec := LeaseSpec{
|
||||
HolderIdentity: nil,
|
||||
AcquireTime: nil,
|
||||
RenewTime: nil,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(spec)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"acquireTime":null`)
|
||||
assert.Contains(t, string(data), `"renewTime":null`)
|
||||
}
|
||||
|
||||
func TestLeaseJSONRoundTrip(t *testing.T) {
|
||||
holder := "pod-abc"
|
||||
dur := int32(300)
|
||||
now := MicroTime{Time: time.Now().UTC().Truncate(time.Microsecond)}
|
||||
|
||||
original := Lease{
|
||||
APIVersion: "coordination.k8s.io/v1",
|
||||
Kind: "Lease",
|
||||
Metadata: LeaseMetadata{
|
||||
Name: "cert-lock-abcdef0123456789",
|
||||
Namespace: "default",
|
||||
ResourceVersion: "12345",
|
||||
Annotations: map[string]string{
|
||||
"netbird.io/domain": "app.example.com",
|
||||
},
|
||||
},
|
||||
Spec: LeaseSpec{
|
||||
HolderIdentity: &holder,
|
||||
LeaseDurationSeconds: &dur,
|
||||
AcquireTime: &now,
|
||||
RenewTime: &now,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded Lease
|
||||
require.NoError(t, json.Unmarshal(data, &decoded))
|
||||
|
||||
assert.Equal(t, original.Metadata.Name, decoded.Metadata.Name)
|
||||
assert.Equal(t, original.Metadata.ResourceVersion, decoded.Metadata.ResourceVersion)
|
||||
assert.Equal(t, *original.Spec.HolderIdentity, *decoded.Spec.HolderIdentity)
|
||||
assert.Equal(t, *original.Spec.LeaseDurationSeconds, *decoded.Spec.LeaseDurationSeconds)
|
||||
assert.True(t, original.Spec.AcquireTime.Equal(decoded.Spec.AcquireTime.Time))
|
||||
}
|
||||
149
proxy/internal/metrics/metrics.go
Normal file
149
proxy/internal/metrics/metrics.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
requestsTotal prometheus.Counter
|
||||
activeRequests prometheus.Gauge
|
||||
configuredDomains prometheus.Gauge
|
||||
pathsPerDomain *prometheus.GaugeVec
|
||||
requestDuration *prometheus.HistogramVec
|
||||
backendDuration *prometheus.HistogramVec
|
||||
}
|
||||
|
||||
func New(reg prometheus.Registerer) *Metrics {
|
||||
promFactory := promauto.With(reg)
|
||||
return &Metrics{
|
||||
requestsTotal: promFactory.NewCounter(prometheus.CounterOpts{
|
||||
Name: "netbird_proxy_requests_total",
|
||||
Help: "Total number of requests made to the netbird proxy",
|
||||
}),
|
||||
activeRequests: promFactory.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "netbird_proxy_active_requests_count",
|
||||
Help: "Current in-flight requests handled by the netbird proxy",
|
||||
}),
|
||||
configuredDomains: promFactory.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "netbird_proxy_domains_count",
|
||||
Help: "Current number of domains configured on the netbird proxy",
|
||||
}),
|
||||
pathsPerDomain: promFactory.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "netbird_proxy_paths_count",
|
||||
Help: "Current number of paths configured on the netbird proxy labelled by domain",
|
||||
},
|
||||
[]string{"domain"},
|
||||
),
|
||||
requestDuration: promFactory.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "netbird_proxy_request_duration_seconds",
|
||||
Help: "Duration of requests made to the netbird proxy",
|
||||
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
},
|
||||
[]string{"status", "size", "method", "host", "path"},
|
||||
),
|
||||
backendDuration: promFactory.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "netbird_proxy_backend_duration_seconds",
|
||||
Help: "Duration of peer round trip time from the netbird proxy",
|
||||
Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
|
||||
},
|
||||
[]string{"status", "size", "method", "host", "path"},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// responseInterceptor wraps an http.ResponseWriter and records the status
// code passed to WriteHeader and the number of body bytes written.
type responseInterceptor struct {
	http.ResponseWriter
	status int // last status passed to WriteHeader; 0 if it was never called
	size   int // total bytes accepted by the underlying Write
}

// WriteHeader records status and forwards it to the wrapped writer.
func (w *responseInterceptor) WriteHeader(status int) {
	w.status = status
	w.ResponseWriter.WriteHeader(status)
}

// Write forwards b to the wrapped writer and adds the number of bytes it
// reports as written to the running size.
func (w *responseInterceptor) Write(b []byte) (int, error) {
	size, err := w.ResponseWriter.Write(b)
	w.size += size
	return size, err
}

// Unwrap returns the wrapped writer. Embedding the http.ResponseWriter
// interface only promotes Header, Write and WriteHeader, so optional
// capabilities of the underlying writer (flushing, hijacking, deadlines)
// would otherwise be hidden. http.ResponseController follows Unwrap to
// reach them.
func (w *responseInterceptor) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
||||
|
||||
func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
m.requestsTotal.Inc()
|
||||
m.activeRequests.Inc()
|
||||
|
||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(interceptor, r)
|
||||
duration := time.Since(start)
|
||||
|
||||
m.activeRequests.Desc()
|
||||
m.requestDuration.With(prometheus.Labels{
|
||||
"status": strconv.Itoa(interceptor.status),
|
||||
"size": strconv.Itoa(interceptor.size),
|
||||
"method": r.Method,
|
||||
"host": r.Host,
|
||||
"path": r.URL.Path,
|
||||
}).Observe(duration.Seconds())
|
||||
})
|
||||
}
|
||||
|
||||
type roundTripperFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return f(r)
|
||||
}
|
||||
|
||||
func (m *Metrics) RoundTripper(next http.RoundTripper) http.RoundTripper {
|
||||
return roundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
labels := prometheus.Labels{
|
||||
"method": req.Method,
|
||||
"host": req.Host,
|
||||
// Fill potentially empty labels with default values to avoid cardinality issues.
|
||||
"path": "/",
|
||||
"status": "0",
|
||||
"size": "0",
|
||||
}
|
||||
if req.URL != nil {
|
||||
labels["path"] = req.URL.Path
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
res, err := next.RoundTrip(req)
|
||||
duration := time.Since(start)
|
||||
|
||||
// Not all labels will be available if there was an error.
|
||||
if res != nil {
|
||||
labels["status"] = strconv.Itoa(res.StatusCode)
|
||||
labels["size"] = strconv.Itoa(int(res.ContentLength))
|
||||
}
|
||||
|
||||
m.backendDuration.With(labels).Observe(duration.Seconds())
|
||||
|
||||
return res, err
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Metrics) AddMapping(mapping proxy.Mapping) {
|
||||
m.configuredDomains.Inc()
|
||||
m.pathsPerDomain.With(prometheus.Labels{
|
||||
"domain": mapping.Host,
|
||||
}).Set(float64(len(mapping.Paths)))
|
||||
}
|
||||
|
||||
func (m *Metrics) RemoveMapping(mapping proxy.Mapping) {
|
||||
m.configuredDomains.Dec()
|
||||
m.pathsPerDomain.With(prometheus.Labels{
|
||||
"domain": mapping.Host,
|
||||
}).Set(0)
|
||||
}
|
||||
67
proxy/internal/metrics/metrics_test.go
Normal file
67
proxy/internal/metrics/metrics_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package metrics_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/netbirdio/netbird/proxy/internal/metrics"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
// testRoundTripper is an http.RoundTripper stub that ignores the request and
// always returns the configured response and error.
type testRoundTripper struct {
	response *http.Response
	err      error
}

// RoundTrip returns the canned response and error.
func (t *testRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	return t.response, t.err
}
|
||||
|
||||
func TestMetrics_RoundTripper(t *testing.T) {
|
||||
testResponse := http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: http.NoBody,
|
||||
}
|
||||
|
||||
tests := map[string]struct {
|
||||
roundTripper http.RoundTripper
|
||||
request *http.Request
|
||||
response *http.Response
|
||||
err error
|
||||
}{
|
||||
"ok": {
|
||||
roundTripper: &testRoundTripper{response: &testResponse},
|
||||
request: &http.Request{Method: "GET", URL: &url.URL{Path: "/foo"}},
|
||||
response: &testResponse,
|
||||
},
|
||||
"nil url": {
|
||||
roundTripper: &testRoundTripper{response: &testResponse},
|
||||
request: &http.Request{Method: "GET", URL: nil},
|
||||
response: &testResponse,
|
||||
},
|
||||
"nil response": {
|
||||
roundTripper: &testRoundTripper{response: nil},
|
||||
request: &http.Request{Method: "GET", URL: &url.URL{Path: "/foo"}},
|
||||
},
|
||||
}
|
||||
|
||||
m := metrics.New(prometheus.NewRegistry())
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
rt := m.RoundTripper(test.roundTripper)
|
||||
res, err := rt.RoundTrip(test.request)
|
||||
if res != nil && res.Body != nil {
|
||||
defer res.Body.Close()
|
||||
}
|
||||
if diff := cmp.Diff(test.err, err); diff != "" {
|
||||
t.Errorf("Incorrect error (-want +got):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(test.response, res); diff != "" {
|
||||
t.Errorf("Incorrect response (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
187
proxy/internal/proxy/context.go
Normal file
187
proxy/internal/proxy/context.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// requestContextKey is the private key type for values this package stores
// in a request context.
type requestContextKey string

// Keys under which the service ID, account ID and CapturedData pointer are
// stored in a request context.
const (
	serviceIdKey    requestContextKey = "serviceId"
	accountIdKey    requestContextKey = "accountId"
	capturedDataKey requestContextKey = "capturedData"
)
|
||||
|
||||
// ResponseOrigin indicates where a response was generated.
type ResponseOrigin int

const (
	// OriginBackend means the response came from the backend service.
	OriginBackend ResponseOrigin = iota
	// OriginNoRoute means the proxy had no matching host or path.
	OriginNoRoute
	// OriginProxyError means the proxy failed to reach the backend.
	OriginProxyError
	// OriginAuth means the proxy intercepted the request for authentication.
	OriginAuth
)

// String returns the snake_case name of the origin. Any value outside the
// declared constants is reported as "backend".
func (o ResponseOrigin) String() string {
	names := [...]string{
		OriginBackend:    "backend",
		OriginNoRoute:    "no_route",
		OriginProxyError: "proxy_error",
		OriginAuth:       "auth",
	}
	if o < 0 || int(o) >= len(names) {
		return "backend"
	}
	return names[o]
}
|
||||
|
||||
// CapturedData is a mutable struct that allows downstream handlers
|
||||
// to pass data back up the middleware chain.
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
RequestID string
|
||||
ServiceId string
|
||||
AccountId types.AccountID
|
||||
Origin ResponseOrigin
|
||||
ClientIP string
|
||||
UserID string
|
||||
AuthMethod string
|
||||
}
|
||||
|
||||
// GetRequestID safely gets the request ID
|
||||
func (c *CapturedData) GetRequestID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.RequestID
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
func (c *CapturedData) SetServiceId(serviceId string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ServiceId = serviceId
|
||||
}
|
||||
|
||||
// GetServiceId safely gets the service ID
|
||||
func (c *CapturedData) GetServiceId() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ServiceId
|
||||
}
|
||||
|
||||
// SetAccountId safely sets the account ID
|
||||
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AccountId = accountId
|
||||
}
|
||||
|
||||
// GetAccountId safely gets the account ID
|
||||
func (c *CapturedData) GetAccountId() types.AccountID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AccountId
|
||||
}
|
||||
|
||||
// SetOrigin safely sets the response origin
|
||||
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Origin = origin
|
||||
}
|
||||
|
||||
// GetOrigin safely gets the response origin
|
||||
func (c *CapturedData) GetOrigin() ResponseOrigin {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.Origin
|
||||
}
|
||||
|
||||
// SetClientIP safely sets the resolved client IP.
|
||||
func (c *CapturedData) SetClientIP(ip string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ClientIP = ip
|
||||
}
|
||||
|
||||
// GetClientIP safely gets the resolved client IP.
|
||||
func (c *CapturedData) GetClientIP() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ClientIP
|
||||
}
|
||||
|
||||
// SetUserID safely sets the authenticated user ID.
|
||||
func (c *CapturedData) SetUserID(userID string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.UserID = userID
|
||||
}
|
||||
|
||||
// GetUserID safely gets the authenticated user ID.
|
||||
func (c *CapturedData) GetUserID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.UserID
|
||||
}
|
||||
|
||||
// SetAuthMethod safely sets the authentication method used.
|
||||
func (c *CapturedData) SetAuthMethod(method string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AuthMethod = method
|
||||
}
|
||||
|
||||
// GetAuthMethod safely gets the authentication method used.
|
||||
func (c *CapturedData) GetAuthMethod() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AuthMethod
|
||||
}
|
||||
|
||||
// WithCapturedData adds a CapturedData struct to the context
|
||||
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
|
||||
return context.WithValue(ctx, capturedDataKey, data)
|
||||
}
|
||||
|
||||
// CapturedDataFromContext retrieves the CapturedData from context
|
||||
func CapturedDataFromContext(ctx context.Context) *CapturedData {
|
||||
v := ctx.Value(capturedDataKey)
|
||||
data, ok := v.(*CapturedData)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func withServiceId(ctx context.Context, serviceId string) context.Context {
|
||||
return context.WithValue(ctx, serviceIdKey, serviceId)
|
||||
}
|
||||
|
||||
func ServiceIdFromContext(ctx context.Context) string {
|
||||
v := ctx.Value(serviceIdKey)
|
||||
serviceId, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return serviceId
|
||||
}
|
||||
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIdKey, accountId)
|
||||
}
|
||||
|
||||
func AccountIdFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIdKey)
|
||||
accountId, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountId
|
||||
}
|
||||
130
proxy/internal/proxy/proxy_bench_test.go
Normal file
130
proxy/internal/proxy/proxy_bench_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package proxy_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// nopTransport is an http.RoundTripper that performs no I/O and answers
// every request with an empty 200 response.
type nopTransport struct{}

// RoundTrip ignores the request and returns a 200 response with no body.
func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       http.NoBody,
	}
	return resp, nil
}
|
||||
|
||||
func BenchmarkServeHTTP(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: map[string]*url.URL{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://app.example.com", nil)
|
||||
req.Host = "app.example.com"
|
||||
req.RemoteAddr = "203.0.113.50:12345"
|
||||
|
||||
for b.Loop() {
|
||||
rp.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTPHostCount(b *testing.B) {
|
||||
hostCounts := []int{1, 10, 100, 1_000, 10_000}
|
||||
|
||||
for _, hostCount := range hostCounts {
|
||||
b.Run(fmt.Sprintf("hosts=%d", hostCount), func(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
|
||||
var target string
|
||||
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(hostCount)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for i := range hostCount {
|
||||
id := rand.Text()
|
||||
host := fmt.Sprintf("%s.example.com", id)
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = id
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: id,
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: host,
|
||||
Paths: map[string]*url.URL{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://"+target+"/", nil)
|
||||
req.Host = target
|
||||
req.RemoteAddr = "203.0.113.50:12345"
|
||||
|
||||
for b.Loop() {
|
||||
rp.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTPPathCount(b *testing.B) {
|
||||
pathCounts := []int{1, 5, 10, 25, 50}
|
||||
|
||||
for _, pathCount := range pathCounts {
|
||||
b.Run(fmt.Sprintf("paths=%d", pathCount), func(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
|
||||
var target string
|
||||
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(pathCount)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
paths := make(map[string]*url.URL, pathCount)
|
||||
for i := range pathCount {
|
||||
path := "/" + rand.Text()
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = path
|
||||
}
|
||||
paths[path] = &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||
}
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: paths,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://app.example.com"+target, nil)
|
||||
req.Host = "app.example.com"
|
||||
req.RemoteAddr = "203.0.113.50:12345"
|
||||
|
||||
for b.Loop() {
|
||||
rp.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
406
proxy/internal/proxy/reverseproxy.go
Normal file
406
proxy/internal/proxy/reverseproxy.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
transport http.RoundTripper
|
||||
// forwardedProto overrides the X-Forwarded-Proto header value.
|
||||
// Valid values: "auto" (detect from TLS), "http", "https".
|
||||
forwardedProto string
|
||||
// trustedProxies is a list of IP prefixes for trusted upstream proxies.
|
||||
// When the direct connection comes from a trusted proxy, forwarding
|
||||
// headers are preserved and appended to instead of being stripped.
|
||||
trustedProxies []netip.Prefix
|
||||
mappingsMux sync.RWMutex
|
||||
mappings map[string]Mapping
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewReverseProxy configures a new NetBird ReverseProxy.
|
||||
// This is a wrapper around an httputil.ReverseProxy set
|
||||
// to dynamically route requests based on internal mapping
|
||||
// between requested URLs and targets.
|
||||
// The internal mappings can be modified using the AddMapping
|
||||
// and RemoveMapping functions.
|
||||
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &ReverseProxy{
|
||||
transport: transport,
|
||||
forwardedProto: forwardedProto,
|
||||
trustedProxies: trustedProxies,
|
||||
mappings: make(map[string]Mapping),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
result, exists := p.findTargetForRequest(r)
|
||||
if !exists {
|
||||
if cd := CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(OriginNoRoute)
|
||||
}
|
||||
requestID := getRequestID(r)
|
||||
web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
|
||||
"The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.",
|
||||
requestID, web.ErrorStatus{Proxy: true, Destination: false})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), result.serviceID)
|
||||
// Set the accountId in the context for later retrieval (for middleware).
|
||||
ctx = withAccountId(ctx, result.accountID)
|
||||
// Set the accountId in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, result.accountID)
|
||||
|
||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||
// pointer in the context, and mutate the struct here so outer middleware can read it.
|
||||
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
|
||||
capturedData.SetServiceId(result.serviceID)
|
||||
capturedData.SetAccountId(result.accountID)
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
||||
Transport: p.transport,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
}
|
||||
if result.rewriteRedirects {
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
|
||||
}
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
|
||||
// inbound requests to target the backend service while setting security-relevant
|
||||
// forwarding headers and stripping proxy authentication credentials.
|
||||
// When passHostHeader is true, the original client Host header is preserved
|
||||
// instead of being rewritten to the backend's address.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
|
||||
return func(r *httputil.ProxyRequest) {
|
||||
// Strip the matched path prefix from the incoming request path before
|
||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||
if r.Out.URL.Path == "" {
|
||||
r.Out.URL.Path = "/"
|
||||
}
|
||||
r.Out.URL.RawPath = ""
|
||||
}
|
||||
|
||||
r.SetURL(target)
|
||||
if passHostHeader {
|
||||
r.Out.Host = r.In.Host
|
||||
} else {
|
||||
r.Out.Host = target.Host
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
|
||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||
p.setTrustedForwardingHeaders(r, clientIP)
|
||||
} else {
|
||||
p.setUntrustedForwardingHeaders(r, clientIP)
|
||||
}
|
||||
|
||||
stripSessionCookie(r)
|
||||
stripSessionTokenQuery(r)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteLocationFunc returns a ModifyResponse function that rewrites Location
|
||||
// headers in backend responses when they point to the backend's address,
|
||||
// replacing them with the public-facing host and scheme.
|
||||
func (p *ReverseProxy) rewriteLocationFunc(target *url.URL, matchedPath string, inReq *http.Request) func(*http.Response) error {
|
||||
publicHost := inReq.Host
|
||||
publicScheme := auth.ResolveProto(p.forwardedProto, inReq.TLS)
|
||||
|
||||
return func(resp *http.Response) error {
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
locURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Location header %q: %w", location, err)
|
||||
}
|
||||
|
||||
// Only rewrite absolute URLs that point to the backend.
|
||||
if locURL.Host == "" || !hostsEqual(locURL, target) {
|
||||
return nil
|
||||
}
|
||||
|
||||
locURL.Host = publicHost
|
||||
locURL.Scheme = publicScheme
|
||||
|
||||
// Re-add the stripped path prefix so the client reaches the correct route.
|
||||
// TrimRight prevents double slashes when matchedPath has a trailing slash.
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
locURL.Path = strings.TrimRight(matchedPath, "/") + "/" + strings.TrimLeft(locURL.Path, "/")
|
||||
}
|
||||
|
||||
resp.Header.Set("Location", locURL.String())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// hostsEqual compares two URL authorities, normalizing default ports per
|
||||
// RFC 3986 Section 6.2.3 (https://443 == https, http://80 == http).
|
||||
func hostsEqual(a, b *url.URL) bool {
|
||||
return normalizeHost(a) == normalizeHost(b)
|
||||
}
|
||||
|
||||
// normalizeHost returns the URL's authority in a form suitable for equality
// checks: when the port is absent, empty, or the scheme's default (443 for
// https, 80 for http) only the hostname is returned; otherwise the Host
// field is returned unchanged, port included.
//
// The port-less form always comes from URL.Hostname, which strips the
// brackets of IPv6 literals. That way "https://[::1]" and "https://[::1]:443"
// both normalize to "::1" instead of comparing "[::1]" against "::1".
func normalizeHost(u *url.URL) string {
	port := u.Port()
	isDefault := port == "" ||
		(u.Scheme == "https" && port == "443") ||
		(u.Scheme == "http" && port == "80")
	if isDefault {
		return u.Hostname()
	}
	return u.Host
}
|
||||
|
||||
// setTrustedForwardingHeaders appends to the existing forwarding header chain
|
||||
// and preserves upstream-provided headers when the direct connection is from
|
||||
// a trusted proxy.
|
||||
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
// Append the direct connection IP to the existing X-Forwarded-For chain.
|
||||
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
|
||||
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
|
||||
if realIP := r.In.Header.Get("X-Real-IP"); realIP != "" {
|
||||
r.Out.Header.Set("X-Real-IP", realIP)
|
||||
} else {
|
||||
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
|
||||
r.Out.Header.Set("X-Real-IP", resolved)
|
||||
}
|
||||
|
||||
// Preserve upstream X-Forwarded-Host if present.
|
||||
if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" {
|
||||
r.Out.Header.Set("X-Forwarded-Host", fwdHost)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
}
|
||||
|
||||
// Trust upstream X-Forwarded-Proto; fall back to local resolution.
|
||||
if fwdProto := r.In.Header.Get("X-Forwarded-Proto"); fwdProto != "" {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", fwdProto)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", auth.ResolveProto(p.forwardedProto, r.In.TLS))
|
||||
}
|
||||
|
||||
// Trust upstream X-Forwarded-Port; fall back to local computation.
|
||||
if fwdPort := r.In.Header.Get("X-Forwarded-Port"); fwdPort != "" {
|
||||
r.Out.Header.Set("X-Forwarded-Port", fwdPort)
|
||||
} else {
|
||||
resolvedProto := r.Out.Header.Get("X-Forwarded-Proto")
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, resolvedProto))
|
||||
}
|
||||
}
|
||||
|
||||
// setUntrustedForwardingHeaders strips all incoming forwarding headers and
|
||||
// sets them fresh based on the direct connection. This is the default
|
||||
// behavior when no trusted proxies are configured or the direct connection
|
||||
// is from an untrusted source.
|
||||
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Real-IP", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
r.Out.Header.Set("X-Forwarded-Proto", proto)
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
|
||||
}
|
||||
|
||||
// stripSessionCookie removes the proxy's session cookie from the outgoing
|
||||
// request while preserving all other cookies.
|
||||
func stripSessionCookie(r *httputil.ProxyRequest) {
|
||||
cookies := r.In.Cookies()
|
||||
r.Out.Header.Del("Cookie")
|
||||
for _, c := range cookies {
|
||||
if c.Name != auth.SessionCookieName {
|
||||
r.Out.AddCookie(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stripSessionTokenQuery drops the OIDC session_token query parameter from
// the outgoing URL so the credential is not leaked to backends. The query
// string is only re-encoded when the parameter is actually present.
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
	const param = "session_token"

	params := r.Out.URL.Query()
	if _, present := params[param]; !present {
		return
	}

	delete(params, param)
	r.Out.URL.RawQuery = params.Encode()
}
|
||||
|
||||
// extractClientIP returns the host part of a host:port address such as
// http.Request.RemoteAddr. Input that cannot be split into host and port is
// returned unchanged.
func extractClientIP(remoteAddr string) string {
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	return remoteAddr
}
|
||||
|
||||
// extractForwardedPort returns the explicit port of the Host header when it
// has one; otherwise it returns the standard port for the resolved protocol:
// "443" for "https" and "80" for anything else.
func extractForwardedPort(host, resolvedProto string) string {
	if _, explicit, err := net.SplitHostPort(host); err == nil && explicit != "" {
		return explicit
	}

	switch resolvedProto {
	case "https":
		return "443"
	default:
		return "80"
	}
}
|
||||
|
||||
// proxyErrorHandler handles errors from the reverse proxy and serves
|
||||
// user-friendly error pages instead of raw error responses.
|
||||
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if cd := CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(OriginProxyError)
|
||||
}
|
||||
requestID := getRequestID(r)
|
||||
clientIP := getClientIP(r)
|
||||
title, message, code, status := classifyProxyError(err)
|
||||
|
||||
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
|
||||
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
|
||||
|
||||
web.ServeErrorPage(w, r, code, title, message, requestID, status)
|
||||
}
|
||||
|
||||
// getClientIP retrieves the resolved client IP from context.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
|
||||
return capturedData.GetClientIP()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRequestID retrieves the request ID from context or returns empty string.
|
||||
func getRequestID(r *http.Request) string {
|
||||
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
|
||||
return capturedData.GetRequestID()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyProxyError determines the appropriate error title, message, HTTP
|
||||
// status code, and component status based on the error type.
|
||||
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded),
|
||||
isNetTimeout(err):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case errors.Is(err, context.Canceled):
|
||||
return "Request Canceled",
|
||||
"The request was canceled before it could be completed. Please refresh the page and try again.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrNoAccountID):
|
||||
return "Configuration Error",
|
||||
"The request could not be processed due to a configuration issue. Please refresh the page and try again.",
|
||||
http.StatusInternalServerError,
|
||||
web.ErrorStatus{Proxy: false, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrNoPeerConnection),
|
||||
errors.Is(err, roundtrip.ErrClientStartFailed):
|
||||
return "Proxy Not Connected",
|
||||
"The proxy is not connected to the NetBird network. Please try again later or contact your administrator.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: false, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrTooManyInflight):
|
||||
return "Service Overloaded",
|
||||
"The service is currently handling too many requests. Please try again shortly.",
|
||||
http.StatusServiceUnavailable,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isConnectionRefused(err):
|
||||
return "Service Unavailable",
|
||||
"The connection to the service was refused. Please verify that the service is running and try again.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isHostUnreachable(err):
|
||||
return "Peer Not Connected",
|
||||
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
return "Connection Error",
|
||||
"An unexpected error occurred while connecting to the service. Please try again later.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
// isConnectionRefused checks for connection refused errors by inspecting
|
||||
// the inner error of a *net.OpError. This handles both standard net errors
|
||||
// (where the inner error is a *os.SyscallError with "connection refused")
|
||||
// and gVisor netstack errors ("connection was refused").
|
||||
func isConnectionRefused(err error) bool {
|
||||
return opErrorContains(err, "refused")
|
||||
}
|
||||
|
||||
// isHostUnreachable checks for host/network unreachable errors by inspecting
|
||||
// the inner error of a *net.OpError. Covers standard net ("no route to host",
|
||||
// "network is unreachable") and gVisor ("host is unreachable", etc.).
|
||||
func isHostUnreachable(err error) bool {
|
||||
return opErrorContains(err, "unreachable") || opErrorContains(err, "no route to host")
|
||||
}
|
||||
|
||||
// isNetTimeout reports whether err, or an error it wraps, implements
// net.Error and describes itself as a timeout.
func isNetTimeout(err error) bool {
	var timeoutErr net.Error
	if !errors.As(err, &timeoutErr) {
		return false
	}
	return timeoutErr.Timeout()
}
|
||||
|
||||
// opErrorContains reports whether err wraps a *net.OpError whose inner
// error message contains substr. Comparing message text rather than error
// values accommodates gVisor netstack errors, which wrap tcpip errors as
// plain strings rather than syscall.Errno values.
func opErrorContains(err error, substr string) bool {
	var netOp *net.OpError
	if !errors.As(err, &netOp) {
		return false
	}
	if netOp.Err == nil {
		return false
	}
	return strings.Contains(netOp.Err.Error(), substr)
}
|
||||
966
proxy/internal/proxy/reverseproxy_test.go
Normal file
966
proxy/internal/proxy/reverseproxy_test.go
Normal file
@@ -0,0 +1,966 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
|
||||
})
|
||||
|
||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", true)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "public.example.com", pr.Out.Host,
|
||||
"Host header should be the original client host")
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host,
|
||||
"URL host (used for TLS/SNI) must still point to the backend")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"should be set to the connecting client IP")
|
||||
})
|
||||
|
||||
t.Run("strips spoofed X-Forwarded-For from client", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"spoofed XFF must be replaced, not appended to")
|
||||
})
|
||||
|
||||
t.Run("strips spoofed X-Real-IP from client", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Real-IP", "10.0.0.1")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"spoofed X-Real-IP must be replaced")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "myapp.example.com:8443", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
// No TLS, but forced to https
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("forced http proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "http"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token-here"})
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
cookies := pr.Out.Cookies()
|
||||
for _, c := range cookies {
|
||||
assert.NotEqual(t, auth.SessionCookieName, c.Name,
|
||||
"proxy session cookie must not be forwarded to backend")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves other cookies", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token"})
|
||||
pr.In.AddCookie(&http.Cookie{Name: "app_session", Value: "app-value"})
|
||||
pr.In.AddCookie(&http.Cookie{Name: "tracking", Value: "track-value"})
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
cookies := pr.Out.Cookies()
|
||||
cookieNames := make([]string, 0, len(cookies))
|
||||
for _, c := range cookies {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
assert.Contains(t, cookieNames, "app_session", "non-proxy cookies should be preserved")
|
||||
assert.Contains(t, cookieNames, "tracking", "non-proxy cookies should be preserved")
|
||||
assert.NotContains(t, cookieNames, auth.SessionCookieName, "proxy cookie must be stripped")
|
||||
})
|
||||
|
||||
t.Run("handles request with no cookies", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get("Cookie"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.URL.Query().Get("session_token"),
|
||||
"OIDC session token must be stripped from backend request")
|
||||
assert.Equal(t, "keep", pr.Out.URL.Query().Get("other"),
|
||||
"other query parameters must be preserved")
|
||||
})
|
||||
|
||||
t.Run("preserves query when no session_token present", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/api?foo=bar&baz=qux", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "bar", pr.Out.URL.Query().Get("foo"))
|
||||
assert.Equal(t, "qux", pr.Out.URL.Query().Get("baz"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.URL.Scheme)
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host)
|
||||
assert.Equal(t, "/app/somepath", pr.Out.URL.Path,
|
||||
"SetURL should join the target base path with the request path")
|
||||
})
|
||||
|
||||
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.URL.Scheme)
|
||||
assert.Equal(t, "backend.example.org:443", pr.Out.URL.Host)
|
||||
assert.Equal(t, "/app/", pr.Out.URL.Path,
|
||||
"matched path prefix should be stripped before joining with target path")
|
||||
})
|
||||
|
||||
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/app/article/123", pr.Out.URL.Path,
|
||||
"subpath after matched prefix should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
|
||||
{"IPv6 with port", "[::1]:12345", "::1"},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
|
||||
{"IPv6 without brackets fallback", "::1", "::1"},
|
||||
{"empty string fallback", "", ""},
|
||||
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractForwardedPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
resolvedProto string
|
||||
expected string
|
||||
}{
|
||||
{"explicit port in host", "example.com:8443", "https", "8443"},
|
||||
{"explicit port overrides proto default", "example.com:9090", "http", "9090"},
|
||||
{"no port defaults to 443 for https", "example.com", "https", "443"},
|
||||
{"no port defaults to 80 for http", "example.com", "http", "80"},
|
||||
{"IPv6 host with port", "[::1]:8080", "http", "8080"},
|
||||
{"IPv6 host without port", "::1", "https", "443"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractForwardedPort(tt.host, tt.resolvedProto))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
trusted := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
|
||||
|
||||
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50, 10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"))
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
pr.In.Header.Set("X-Real-IP", "203.0.113.50")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"))
|
||||
})
|
||||
|
||||
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"should resolve real client through trusted chain")
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "original.example.com", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"),
|
||||
"should use configured forwardedProto as fallback")
|
||||
})
|
||||
|
||||
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
pr.In.Header.Set("X-Real-IP", "evil")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "evil.example.com")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "9999")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"untrusted: XFF must be replaced")
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"untrusted: X-Real-IP must be replaced")
|
||||
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"),
|
||||
"untrusted: host must be from direct connection")
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"),
|
||||
"untrusted: proto must be locally resolved")
|
||||
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"),
|
||||
"untrusted: port must be locally computed")
|
||||
})
|
||||
|
||||
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"nil trusted list: should strip and use RemoteAddr")
|
||||
})
|
||||
|
||||
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"no upstream XFF: should set direct connection IP")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRewriteFunc_PathForwarding verifies what path the backend actually
|
||||
// receives given different configurations. This simulates the full pipeline:
|
||||
// management builds a target URL (with matching prefix baked into the path),
|
||||
// then the proxy strips the prefix and SetURL re-joins with the target path.
|
||||
func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
// Simulate what ToProtoMapping does: target URL includes the matching
|
||||
// prefix as its path component, so the proxy strips-then-re-adds.
|
||||
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
||||
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/heise/", pr.Out.URL.Path,
|
||||
"backend sees /heise/ because prefix is stripped then re-added by SetURL")
|
||||
})
|
||||
|
||||
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/heise/article/123", pr.Out.URL.Path,
|
||||
"subpath is preserved on top of the re-added prefix")
|
||||
})
|
||||
|
||||
// What the behavior WOULD be if target URL had no path (true stripping)
|
||||
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/", pr.Out.URL.Path,
|
||||
"without path in target URL, backend sees / (true prefix stripping)")
|
||||
})
|
||||
|
||||
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/article/123", pr.Out.URL.Path,
|
||||
"without path in target URL, prefix is truly stripped")
|
||||
})
|
||||
|
||||
// Root path "/" — no stripping expected
|
||||
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.com:443/")
|
||||
rewrite := p.rewriteFunc(target, "/", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/heise", pr.Out.URL.Path,
|
||||
"root path match must not strip anything")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteLocationFunc(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
|
||||
newReq := func(rawURL string) *http.Request {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
parsed, _ := url.Parse(rawURL)
|
||||
r.Host = parsed.Host
|
||||
return r
|
||||
}
|
||||
run := func(p *ReverseProxy, matchedPath string, inReq *http.Request, location string) (*http.Response, error) {
|
||||
t.Helper()
|
||||
modifyResp := p.rewriteLocationFunc(target, matchedPath, inReq) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
if location != "" {
|
||||
resp.Header.Set("Location", location)
|
||||
}
|
||||
err := modifyResp(resp)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
t.Run("rewrites Location pointing to backend", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/login")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("does not rewrite Location pointing to other host", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"https://other.example.com/path")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://other.example.com/path", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("does not rewrite relative Location", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"/dashboard")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/dashboard", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("re-adds stripped path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/users"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/users")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("uses resolved proto for scheme", func(t *testing.T) {
|
||||
resp, err := run(newProxy("auto"), "", newReq("http://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/path")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("no-op when Location header is empty", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), "") //nolint:bodyclose
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("does not prepend root path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/", newReq("https://public.example.com/login"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/login")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
// --- Edge cases: query parameters and fragments ---
|
||||
|
||||
t.Run("preserves query parameters", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/login?redirect=%2Fdashboard&lang=en")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/login?redirect=%2Fdashboard&lang=en", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves fragment", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/docs#section-2")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/docs#section-2", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves query parameters and fragment together", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/search?q=test&page=1#results")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/search?q=test&page=1#results", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves query parameters with path prefix re-added", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/search"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/search?q=hello")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/api/search?q=hello", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
// --- Edge cases: slash handling ---
|
||||
|
||||
t.Run("no double slash when matchedPath has trailing slash", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/api/", newReq("https://public.example.com/api/users"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/users")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("backend redirect to root with path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/app", newReq("https://public.example.com/app/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("backend redirect to root with trailing-slash path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/app/", newReq("https://public.example.com/app/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves trailing slash on redirect path", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/path/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/path/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("backend redirect to bare root", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
// --- Edge cases: host/port matching ---
|
||||
|
||||
t.Run("does not rewrite when backend host matches but port differs", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:9090/other")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://backend.internal:9090/other", resp.Header.Get("Location"),
|
||||
"Different port means different host authority, must not rewrite")
|
||||
})
|
||||
|
||||
t.Run("rewrites when redirect omits default port matching target", func(t *testing.T) {
|
||||
// Target is backend.internal:8080, redirect is to backend.internal (no port).
|
||||
// These are different authorities, so should NOT rewrite.
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal/path")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://backend.internal/path", resp.Header.Get("Location"),
|
||||
"backend.internal != backend.internal:8080, must not rewrite")
|
||||
})
|
||||
|
||||
t.Run("rewrites when target has :443 but redirect omits it for https", func(t *testing.T) {
|
||||
// Target: heise.de:443, redirect: https://heise.de/path (no :443 because it's default)
|
||||
// Per RFC 3986, these are the same authority.
|
||||
target443, _ := url.Parse("https://heise.de:443")
|
||||
p := newProxy("https")
|
||||
modifyResp := p.rewriteLocationFunc(target443, "", newReq("https://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "https://heise.de/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
|
||||
"heise.de:443 and heise.de are the same for https")
|
||||
})
|
||||
|
||||
t.Run("rewrites when target has :80 but redirect omits it for http", func(t *testing.T) {
|
||||
target80, _ := url.Parse("http://backend.local:80")
|
||||
p := newProxy("http")
|
||||
modifyResp := p.rewriteLocationFunc(target80, "", newReq("http://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "http://backend.local/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"),
|
||||
"backend.local:80 and backend.local are the same for http")
|
||||
})
|
||||
|
||||
t.Run("rewrites when redirect has :443 but target omits it", func(t *testing.T) {
|
||||
targetNoPort, _ := url.Parse("https://heise.de")
|
||||
p := newProxy("https")
|
||||
modifyResp := p.rewriteLocationFunc(targetNoPort, "", newReq("https://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "https://heise.de:443/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
|
||||
"heise.de and heise.de:443 are the same for https")
|
||||
})
|
||||
|
||||
t.Run("does not conflate non-default ports", func(t *testing.T) {
|
||||
target8443, _ := url.Parse("https://backend.internal:8443")
|
||||
p := newProxy("https")
|
||||
modifyResp := p.rewriteLocationFunc(target8443, "", newReq("https://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "https://backend.internal/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://backend.internal/path", resp.Header.Get("Location"),
|
||||
"backend.internal:8443 != backend.internal (port 443), must not rewrite")
|
||||
})
|
||||
|
||||
// --- Edge cases: encoded paths ---
|
||||
|
||||
t.Run("preserves percent-encoded path segments", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/path%20with%20spaces/file%2Fname")
|
||||
|
||||
require.NoError(t, err)
|
||||
loc := resp.Header.Get("Location")
|
||||
assert.Contains(t, loc, "public.example.com")
|
||||
parsed, err := url.Parse(loc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/path with spaces/file/name", parsed.Path)
|
||||
})
|
||||
|
||||
t.Run("preserves encoded query parameters with path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/v1", newReq("https://public.example.com/v1/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/redirect?url=http%3A%2F%2Fexample.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/v1/redirect?url=http%3A%2F%2Fexample.com", resp.Header.Get("Location"))
|
||||
})
|
||||
}
|
||||
|
||||
// newProxyRequest creates an httputil.ProxyRequest suitable for testing
|
||||
// the Rewrite function. It simulates what httputil.ReverseProxy does internally:
|
||||
// Out is a shallow clone of In with headers copied.
|
||||
func newProxyRequest(t *testing.T, rawURL, remoteAddr string) *httputil.ProxyRequest {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
in := httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
in.RemoteAddr = remoteAddr
|
||||
in.Host = parsed.Host
|
||||
|
||||
out := in.Clone(in.Context())
|
||||
out.Header = in.Header.Clone()
|
||||
|
||||
return &httputil.ProxyRequest{In: in, Out: out}
|
||||
}
|
||||
|
||||
func TestClassifyProxyError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantTitle string
|
||||
wantCode int
|
||||
wantStatus web.ErrorStatus
|
||||
}{
|
||||
{
|
||||
name: "context deadline exceeded",
|
||||
err: context.DeadlineExceeded,
|
||||
wantTitle: "Request Timeout",
|
||||
wantCode: http.StatusGatewayTimeout,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "wrapped deadline exceeded",
|
||||
err: fmt.Errorf("dial: %w", context.DeadlineExceeded),
|
||||
wantTitle: "Request Timeout",
|
||||
wantCode: http.StatusGatewayTimeout,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "context canceled",
|
||||
err: context.Canceled,
|
||||
wantTitle: "Request Canceled",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "no account ID",
|
||||
err: roundtrip.ErrNoAccountID,
|
||||
wantTitle: "Configuration Error",
|
||||
wantCode: http.StatusInternalServerError,
|
||||
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "no peer connection",
|
||||
err: fmt.Errorf("%w for account: abc", roundtrip.ErrNoPeerConnection),
|
||||
wantTitle: "Proxy Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "client not started",
|
||||
err: fmt.Errorf("%w: %w", roundtrip.ErrClientStartFailed, errors.New("engine init failed")),
|
||||
wantTitle: "Proxy Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "syscall ECONNREFUSED via os.SyscallError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED},
|
||||
},
|
||||
wantTitle: "Service Unavailable",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "gvisor connection was refused",
|
||||
err: &net.OpError{
|
||||
Op: "connect",
|
||||
Net: "tcp",
|
||||
Err: errors.New("connection was refused"),
|
||||
},
|
||||
wantTitle: "Service Unavailable",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "syscall EHOSTUNREACH via os.SyscallError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH},
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "syscall ENETUNREACH via os.SyscallError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.ENETUNREACH},
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "gvisor host is unreachable",
|
||||
err: &net.OpError{
|
||||
Op: "connect",
|
||||
Net: "tcp",
|
||||
Err: errors.New("host is unreachable"),
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "gvisor network is unreachable",
|
||||
err: &net.OpError{
|
||||
Op: "connect",
|
||||
Net: "tcp",
|
||||
Err: errors.New("network is unreachable"),
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "standard no route to host",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH},
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "unknown error falls to default",
|
||||
err: errors.New("something unexpected"),
|
||||
wantTitle: "Connection Error",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
title, _, code, status := classifyProxyError(tt.err)
|
||||
assert.Equal(t, tt.wantTitle, title, "title")
|
||||
assert.Equal(t, tt.wantCode, code, "status code")
|
||||
assert.Equal(t, tt.wantStatus, status, "component status")
|
||||
})
|
||||
}
|
||||
}
|
||||
84
proxy/internal/proxy/servicemapping.go
Normal file
84
proxy/internal/proxy/servicemapping.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
}
|
||||
|
||||
type targetResult struct {
|
||||
url *url.URL
|
||||
matchedPath string
|
||||
serviceID string
|
||||
accountID types.AccountID
|
||||
passHostHeader bool
|
||||
rewriteRedirects bool
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
|
||||
p.mappingsMux.RLock()
|
||||
defer p.mappingsMux.RUnlock()
|
||||
|
||||
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
|
||||
host := req.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
m, exists := p.mappings[host]
|
||||
if !exists {
|
||||
p.logger.Debugf("no mapping found for host: %s", host)
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
|
||||
paths := make([]string, 0, len(m.Paths))
|
||||
for path := range m.Paths {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Slice(paths, func(i, j int) bool {
|
||||
return len(paths[i]) > len(paths[j])
|
||||
})
|
||||
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
target := m.Paths[path]
|
||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
|
||||
return targetResult{
|
||||
url: target,
|
||||
matchedPath: path,
|
||||
serviceID: m.ID,
|
||||
accountID: m.AccountID,
|
||||
passHostHeader: m.PassHostHeader,
|
||||
rewriteRedirects: m.RewriteRedirects,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
p.logger.Debugf("no path match for host: %s, path: %s", host, req.URL.Path)
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
p.mappings[m.Host] = m
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
delete(p.mappings, m.Host)
|
||||
}
|
||||
60
proxy/internal/proxy/trustedproxy.go
Normal file
60
proxy/internal/proxy/trustedproxy.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IsTrustedProxy reports whether ipStr parses as an IP address that falls
// within any of the trusted prefixes.
//
// It returns false when the trusted list is empty or ipStr is not a valid IP
// address. An IPv4-mapped IPv6 address (e.g. "::ffff:10.0.0.1") is also tested
// in its plain IPv4 form, so it matches IPv4 prefixes: netip.Prefix.Contains
// on its own never matches a mapped address against an IPv4 prefix, which
// would make a trusted hop written in mapped form look untrusted.
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
	if len(trusted) == 0 {
		return false
	}

	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		return false
	}

	// Keep testing the address exactly as written as well, so IPv6 prefixes
	// that cover the mapped range continue to match.
	mapped := addr.Is4In6()
	unmapped := addr.Unmap()

	for _, prefix := range trusted {
		if prefix.Contains(addr) {
			return true
		}
		if mapped && prefix.Contains(unmapped) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
|
||||
// It walks the XFF chain right-to-left, skipping IPs that match trusted prefixes.
|
||||
// The first untrusted IP is the real client.
|
||||
//
|
||||
// If the trusted list is empty or remoteAddr is not trusted, it returns the
|
||||
// remoteAddr IP directly (ignoring any forwarding headers).
|
||||
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
|
||||
remoteIP := extractClientIP(remoteAddr)
|
||||
|
||||
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
if xff == "" {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
parts := strings.Split(xff, ",")
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
ip := strings.TrimSpace(parts[i])
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
if !IsTrustedProxy(ip, trusted) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||
return first
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
129
proxy/internal/proxy/trustedproxy_test.go
Normal file
129
proxy/internal/proxy/trustedproxy_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsTrustedProxy(t *testing.T) {
|
||||
trusted := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("fd00::/8"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
trusted []netip.Prefix
|
||||
want bool
|
||||
}{
|
||||
{"empty trusted list", "10.0.0.1", nil, false},
|
||||
{"IP within /8 prefix", "10.1.2.3", trusted, true},
|
||||
{"IP within /24 prefix", "192.168.1.100", trusted, true},
|
||||
{"IP outside all prefixes", "203.0.113.50", trusted, false},
|
||||
{"boundary IP just outside prefix", "192.168.2.1", trusted, false},
|
||||
{"unparsable IP", "not-an-ip", trusted, false},
|
||||
{"IPv6 in trusted range", "fd00::1", trusted, true},
|
||||
{"IPv6 outside range", "2001:db8::1", trusted, false},
|
||||
{"empty string", "", trusted, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, IsTrustedProxy(tt.ip, tt.trusted))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveClientIP(t *testing.T) {
|
||||
trusted := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("172.16.0.0/12"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xff string
|
||||
trusted []netip.Prefix
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty trusted list returns RemoteAddr",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4",
|
||||
trusted: nil,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "untrusted RemoteAddr ignores XFF",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4, 10.0.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with single client in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr walks past trusted entries in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "all XFF IPs trusted returns leftmost",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.2",
|
||||
},
|
||||
{
|
||||
name: "XFF with whitespace",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: " 203.0.113.50 , 10.0.0.2 ",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "XFF with empty segments",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50,,10.0.0.2",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "multi-hop with mixed trust",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr without port",
|
||||
remoteAddr: "10.0.0.1",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, ResolveClientIP(tt.remoteAddr, tt.xff, tt.trusted))
|
||||
})
|
||||
}
|
||||
}
|
||||
575
proxy/internal/roundtrip/netbird.go
Normal file
575
proxy/internal/roundtrip/netbird.go
Normal file
@@ -0,0 +1,575 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/embed"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// deviceNamePrefix is a fixed name prefix.
// NOTE(review): presumably prepended to the embedded client's device name;
// its use is outside this view, confirm against the caller.
const deviceNamePrefix = "ingress-proxy-"

// backendKey identifies a backend by the host:port of its target URL.
type backendKey = string

// Sentinel errors of this package.
var (
	// ErrNoAccountID reports that a request context carries no account ID.
	ErrNoAccountID = errors.New("no account ID in request context")
	// ErrNoPeerConnection reports that no embedded client exists for the account.
	ErrNoPeerConnection = errors.New("no peer connection found")
	// ErrClientStartFailed reports that the embedded client failed to start.
	ErrClientStartFailed = errors.New("client start failed")
	// ErrTooManyInflight reports that the per-backend in-flight limit is reached.
	ErrTooManyInflight = errors.New("too many in-flight requests")
)
|
||||
|
||||
// domainInfo is the metadata kept for each domain registered on a client.
type domainInfo struct {
	// serviceID is the service ID the domain was registered with.
	serviceID string
}
|
||||
|
||||
type domainNotification struct {
|
||||
domain domain.Domain
|
||||
serviceID string
|
||||
}
|
||||
|
||||
// clientEntry holds an embedded NetBird client and tracks which domains use it.
|
||||
type clientEntry struct {
|
||||
client *embed.Client
|
||||
transport *http.Transport
|
||||
domains map[domain.Domain]domainInfo
|
||||
createdAt time.Time
|
||||
started bool
|
||||
// Per-backend in-flight limiting keyed by target host:port.
|
||||
// TODO: clean up stale entries when backend targets change.
|
||||
inflightMu sync.Mutex
|
||||
inflightMap map[backendKey]chan struct{}
|
||||
maxInflight int
|
||||
}
|
||||
|
||||
// acquireInflight attempts to acquire an in-flight slot for the given backend.
|
||||
// It returns a release function that must always be called, and true on success.
|
||||
func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bool) {
|
||||
noop := func() {}
|
||||
if e.maxInflight <= 0 {
|
||||
return noop, true
|
||||
}
|
||||
|
||||
e.inflightMu.Lock()
|
||||
sem, exists := e.inflightMap[backend]
|
||||
if !exists {
|
||||
sem = make(chan struct{}, e.maxInflight)
|
||||
e.inflightMap[backend] = sem
|
||||
}
|
||||
e.inflightMu.Unlock()
|
||||
|
||||
select {
|
||||
case sem <- struct{}{}:
|
||||
return func() { <-sem }, true
|
||||
default:
|
||||
return noop, false
|
||||
}
|
||||
}
|
||||
|
||||
// statusNotifier reports whether the tunnel backing a domain is connected.
// NetBird calls it with connected=true once a client has started and with
// connected=false when a domain is removed.
type statusNotifier interface {
	NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
}
|
||||
|
||||
type managementClient interface {
|
||||
CreateProxyPeer(ctx context.Context, req *proto.CreateProxyPeerRequest, opts ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error)
|
||||
}
|
||||
|
||||
// NetBird provides an http.RoundTripper implementation
|
||||
// backed by underlying NetBird connections.
|
||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||
type NetBird struct {
|
||||
mgmtAddr string
|
||||
proxyID string
|
||||
proxyAddr string
|
||||
wgPort int
|
||||
logger *log.Logger
|
||||
mgmtClient managementClient
|
||||
transportCfg transportConfig
|
||||
|
||||
clientsMux sync.RWMutex
|
||||
clients map[types.AccountID]*clientEntry
|
||||
initLogOnce sync.Once
|
||||
statusNotifier statusNotifier
|
||||
}
|
||||
|
||||
// ClientDebugInfo contains debug information about a client.
|
||||
type ClientDebugInfo struct {
|
||||
AccountID types.AccountID
|
||||
DomainCount int
|
||||
Domains domain.List
|
||||
HasClient bool
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// accountIDContextKey is the unexported key type under which the account ID
// is stored in a context.Context (see WithAccountID / AccountIDFromContext).
type accountIDContextKey struct{}
|
||||
|
||||
// AddPeer registers a domain for an account. If the account doesn't have a client yet,
|
||||
// one is created by authenticating with the management server using the provided token.
|
||||
// Multiple domains can share the same client.
|
||||
func (n *NetBird) AddPeer(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
// Client already exists for this account, just register the domain
|
||||
entry.domains[d] = domainInfo{serviceID: serviceID}
|
||||
started := entry.started
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("registered domain with existing client")
|
||||
|
||||
// If client is already started, notify this domain as connected immediately
|
||||
if started && n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), serviceID, string(d), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify status for existing client")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
entry, err := n.createClientEntry(ctx, accountID, d, authToken, serviceID)
|
||||
if err != nil {
|
||||
n.clientsMux.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
n.clients[accountID] = entry
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Info("created new client for account")
|
||||
|
||||
// Attempt to start the client in the background; if this fails we will
|
||||
// retry on the first request via RoundTrip.
|
||||
go n.runClientStartup(ctx, accountID, entry.client)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createClientEntry generates a WireGuard keypair, authenticates with management,
|
||||
// and creates an embedded NetBird client. Must be called with clientsMux held.
|
||||
func (n *NetBird) createClientEntry(ctx context.Context, accountID types.AccountID, d domain.Domain, authToken, serviceID string) (*clientEntry, error) {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
}).Debug("generating WireGuard keypair for new peer")
|
||||
|
||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate wireguard private key: %w", err)
|
||||
}
|
||||
publicKey := privateKey.PublicKey()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
"public_key": publicKey.String(),
|
||||
}).Debug("authenticating new proxy peer with management")
|
||||
|
||||
resp, err := n.mgmtClient.CreateProxyPeer(ctx, &proto.CreateProxyPeerRequest{
|
||||
ServiceId: serviceID,
|
||||
AccountId: string(accountID),
|
||||
Token: authToken,
|
||||
WireguardPublicKey: publicKey.String(),
|
||||
Cluster: n.proxyAddr,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate proxy peer with management: %w", err)
|
||||
}
|
||||
if resp != nil && !resp.GetSuccess() {
|
||||
errMsg := "unknown error"
|
||||
if resp.ErrorMessage != nil {
|
||||
errMsg = *resp.ErrorMessage
|
||||
}
|
||||
return nil, fmt.Errorf("proxy peer authentication failed: %s", errMsg)
|
||||
}
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
"public_key": publicKey.String(),
|
||||
}).Info("proxy peer authenticated successfully with management")
|
||||
|
||||
n.initLogOnce.Do(func() {
|
||||
if err := util.InitLog(log.WarnLevel.String(), util.LogConsole); err != nil {
|
||||
n.logger.WithField("account_id", accountID).Warnf("failed to initialize embedded client logging: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Create embedded NetBird client with the generated private key.
|
||||
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
||||
client, err := embed.New(embed.Options{
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.mgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
LogLevel: log.WarnLevel.String(),
|
||||
BlockInbound: true,
|
||||
WireguardPort: &n.wgPort,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
|
||||
// Create a transport using the client dialer. We do this instead of using
|
||||
// the client's HTTPClient to avoid issues with request validation that do
|
||||
// not work with reverse proxied requests.
|
||||
return &clientEntry{
|
||||
client: client,
|
||||
domains: map[domain.Domain]domainInfo{d: {serviceID: serviceID}},
|
||||
transport: &http.Transport{
|
||||
DialContext: client.DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: n.transportCfg.maxIdleConns,
|
||||
MaxIdleConnsPerHost: n.transportCfg.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: n.transportCfg.maxConnsPerHost,
|
||||
IdleConnTimeout: n.transportCfg.idleConnTimeout,
|
||||
TLSHandshakeTimeout: n.transportCfg.tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: n.transportCfg.expectContinueTimeout,
|
||||
ResponseHeaderTimeout: n.transportCfg.responseHeaderTimeout,
|
||||
WriteBufferSize: n.transportCfg.writeBufferSize,
|
||||
ReadBufferSize: n.transportCfg.readBufferSize,
|
||||
DisableCompression: n.transportCfg.disableCompression,
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: false,
|
||||
inflightMap: make(map[backendKey]chan struct{}),
|
||||
maxInflight: n.transportCfg.maxInflight,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// runClientStartup starts the client and notifies registered domains on success.
|
||||
func (n *NetBird) runClientStartup(ctx context.Context, accountID types.AccountID, client *embed.Client) {
|
||||
startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
n.logger.WithField("account_id", accountID).Warn("netbird client start timed out, will retry on first request")
|
||||
} else {
|
||||
n.logger.WithField("account_id", accountID).WithError(err).Error("failed to start netbird client")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark client as started and collect domains to notify outside the lock.
|
||||
n.clientsMux.Lock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if exists {
|
||||
entry.started = true
|
||||
}
|
||||
var domainsToNotify []domainNotification
|
||||
if exists {
|
||||
for dom, info := range entry.domains {
|
||||
domainsToNotify = append(domainsToNotify, domainNotification{domain: dom, serviceID: info.serviceID})
|
||||
}
|
||||
}
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
if n.statusNotifier == nil {
|
||||
return
|
||||
}
|
||||
for _, dn := range domainsToNotify {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), dn.serviceID, string(dn.domain), true); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).WithError(err).Warn("failed to notify tunnel connection status")
|
||||
} else {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": dn.domain,
|
||||
}).Info("notified management about tunnel connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RemovePeer unregisters a domain from an account. The client is only stopped
|
||||
// when no domains are using it anymore.
|
||||
func (n *NetBird) RemovePeer(ctx context.Context, accountID types.AccountID, d domain.Domain) error {
|
||||
n.clientsMux.Lock()
|
||||
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.Unlock()
|
||||
n.logger.WithField("account_id", accountID).Debug("remove peer: account not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get domain info before deleting
|
||||
domInfo, domainExists := entry.domains[d]
|
||||
if !domainExists {
|
||||
n.clientsMux.Unlock()
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).Debug("remove peer: domain not registered")
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(entry.domains, d)
|
||||
|
||||
// If there are still domains using this client, keep it running
|
||||
if len(entry.domains) > 0 {
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
"remaining_domains": len(entry.domains),
|
||||
}).Debug("unregistered domain, client still in use")
|
||||
|
||||
// Notify this domain as disconnected
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// No more domains using this client, stop it
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).Info("stopping client, no more domains")
|
||||
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
delete(n.clients, accountID)
|
||||
n.clientsMux.Unlock()
|
||||
|
||||
// Notify disconnection before stopping
|
||||
if n.statusNotifier != nil {
|
||||
if err := n.statusNotifier.NotifyStatus(ctx, string(accountID), domInfo.serviceID, string(d), false); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"domain": d,
|
||||
}).WithError(err).Warn("failed to notify tunnel disconnection status")
|
||||
}
|
||||
}
|
||||
|
||||
transport.CloseIdleConnections()
|
||||
|
||||
if err := client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper. It looks up the client for the account
|
||||
// specified in the request context and uses it to dial the backend.
|
||||
func (n *NetBird) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
accountID := AccountIDFromContext(req.Context())
|
||||
if accountID == "" {
|
||||
return nil, ErrNoAccountID
|
||||
}
|
||||
|
||||
// Copy references while holding lock, then unlock early to avoid blocking
|
||||
// other requests during the potentially slow RoundTrip.
|
||||
n.clientsMux.RLock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
n.clientsMux.RUnlock()
|
||||
return nil, fmt.Errorf("%w for account: %s", ErrNoPeerConnection, accountID)
|
||||
}
|
||||
client := entry.client
|
||||
transport := entry.transport
|
||||
n.clientsMux.RUnlock()
|
||||
|
||||
release, ok := entry.acquireInflight(req.URL.Host)
|
||||
defer release()
|
||||
if !ok {
|
||||
return nil, ErrTooManyInflight
|
||||
}
|
||||
|
||||
// Attempt to start the client, if the client is already running then
|
||||
// it will return an error that we ignore, if this hits a timeout then
|
||||
// this request is unprocessable.
|
||||
startCtx, cancel := context.WithTimeout(req.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Start(startCtx); err != nil {
|
||||
if !errors.Is(err, embed.ErrClientAlreadyStarted) {
|
||||
return nil, fmt.Errorf("%w: %w", ErrClientStartFailed, err)
|
||||
}
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
resp, err := transport.RoundTrip(req)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s duration=%s err=%v",
|
||||
req.Method, req.Host, req.URL.String(), accountID, duration.Truncate(time.Millisecond), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n.logger.Debugf("roundtrip: method=%s host=%s url=%s account=%s status=%d duration=%s",
|
||||
req.Method, req.Host, req.URL.String(), accountID, resp.StatusCode, duration.Truncate(time.Millisecond))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// StopAll stops all clients.
|
||||
func (n *NetBird) StopAll(ctx context.Context) error {
|
||||
n.clientsMux.Lock()
|
||||
defer n.clientsMux.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
for accountID, entry := range n.clients {
|
||||
entry.transport.CloseIdleConnections()
|
||||
if err := entry.client.Stop(ctx); err != nil {
|
||||
n.logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
}).WithError(err).Warn("failed to stop netbird client during shutdown")
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
}
|
||||
maps.Clear(n.clients)
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// HasClient returns true if there is a client for the given account.
|
||||
func (n *NetBird) HasClient(accountID types.AccountID) bool {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
_, exists := n.clients[accountID]
|
||||
return exists
|
||||
}
|
||||
|
||||
// DomainCount returns the number of domains registered for the given account.
|
||||
// Returns 0 if the account has no client.
|
||||
func (n *NetBird) DomainCount(accountID types.AccountID) int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return 0
|
||||
}
|
||||
return len(entry.domains)
|
||||
}
|
||||
|
||||
// ClientCount returns the total number of active clients.
|
||||
func (n *NetBird) ClientCount() int {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
return len(n.clients)
|
||||
}
|
||||
|
||||
// GetClient returns the embed.Client for the given account ID.
|
||||
func (n *NetBird) GetClient(accountID types.AccountID) (*embed.Client, bool) {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
entry, exists := n.clients[accountID]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
return entry.client, true
|
||||
}
|
||||
|
||||
// ListClientsForDebug returns information about all clients for debug purposes.
|
||||
func (n *NetBird) ListClientsForDebug() map[types.AccountID]ClientDebugInfo {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
|
||||
result := make(map[types.AccountID]ClientDebugInfo)
|
||||
for accountID, entry := range n.clients {
|
||||
domains := make(domain.List, 0, len(entry.domains))
|
||||
for d := range entry.domains {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
result[accountID] = ClientDebugInfo{
|
||||
AccountID: accountID,
|
||||
DomainCount: len(entry.domains),
|
||||
Domains: domains,
|
||||
HasClient: entry.client != nil,
|
||||
CreatedAt: entry.createdAt,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ListClientsForStartup returns all embed.Client instances for health checks.
|
||||
func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
|
||||
n.clientsMux.RLock()
|
||||
defer n.clientsMux.RUnlock()
|
||||
|
||||
result := make(map[types.AccountID]*embed.Client)
|
||||
for accountID, entry := range n.clients {
|
||||
if entry.client != nil {
|
||||
result[accountID] = entry.client
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random
|
||||
// OS-assigned port. A fixed port only works with single-account deployments;
|
||||
// multiple accounts will fail to bind the same port.
|
||||
func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &NetBird{
|
||||
mgmtAddr: mgmtAddr,
|
||||
proxyID: proxyID,
|
||||
proxyAddr: proxyAddr,
|
||||
wgPort: wgPort,
|
||||
logger: logger,
|
||||
clients: make(map[types.AccountID]*clientEntry),
|
||||
statusNotifier: notifier,
|
||||
mgmtClient: mgmtClient,
|
||||
transportCfg: loadTransportConfig(logger),
|
||||
}
|
||||
}
|
||||
|
||||
// WithAccountID adds the account ID to the context.
|
||||
func WithAccountID(ctx context.Context, accountID types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIDContextKey{}, accountID)
|
||||
}
|
||||
|
||||
// AccountIDFromContext retrieves the account ID from the context.
|
||||
func AccountIDFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIDContextKey{})
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
accountID, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountID
|
||||
}
|
||||
107
proxy/internal/roundtrip/netbird_bench_test.go
Normal file
107
proxy/internal/roundtrip/netbird_bench_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// Simple benchmark for comparison with AddPeer contention.
|
||||
func BenchmarkHasClient(b *testing.B) {
|
||||
// Knobs for dialling in:
|
||||
initialClientCount := 100 // Size of initial peer map to generate.
|
||||
|
||||
nb := mockNetBird()
|
||||
|
||||
var target types.AccountID
|
||||
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(initialClientCount)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for i := range initialClientCount {
|
||||
id := types.AccountID(rand.Text())
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = id
|
||||
}
|
||||
nb.clients[id] = &clientEntry{
|
||||
domains: map[domain.Domain]domainInfo{
|
||||
domain.Domain(rand.Text()): {
|
||||
serviceID: rand.Text(),
|
||||
},
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: true,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
nb.HasClient(target)
|
||||
}
|
||||
})
|
||||
b.StopTimer()
|
||||
}
|
||||
|
||||
func BenchmarkHasClientDuringAddPeer(b *testing.B) {
|
||||
// Knobs for dialling in:
|
||||
initialClientCount := 100 // Size of initial peer map to generate.
|
||||
addPeerWorkers := 5 // Number of workers to concurrently call AddPeer.
|
||||
|
||||
nb := mockNetBird()
|
||||
|
||||
// Add random client entries to the netbird instance.
|
||||
// We're trying to test map lock contention, so starting with
|
||||
// a populated map should help with this.
|
||||
// Pick a random one to target for retrieval later.
|
||||
var target types.AccountID
|
||||
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(initialClientCount)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for i := range initialClientCount {
|
||||
id := types.AccountID(rand.Text())
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = id
|
||||
}
|
||||
nb.clients[id] = &clientEntry{
|
||||
domains: map[domain.Domain]domainInfo{
|
||||
domain.Domain(rand.Text()): {
|
||||
serviceID: rand.Text(),
|
||||
},
|
||||
},
|
||||
createdAt: time.Now(),
|
||||
started: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Launch workers that continuously call AddPeer with new random accountIDs.
|
||||
var wg sync.WaitGroup
|
||||
for range addPeerWorkers {
|
||||
wg.Go(func() {
|
||||
for {
|
||||
if err := nb.AddPeer(b.Context(),
|
||||
types.AccountID(rand.Text()),
|
||||
domain.Domain(rand.Text()),
|
||||
rand.Text(),
|
||||
rand.Text()); err != nil {
|
||||
b.Log(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark calling HasClient during AddPeer contention.
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
nb.HasClient(target)
|
||||
}
|
||||
})
|
||||
b.StopTimer()
|
||||
}
|
||||
328
proxy/internal/roundtrip/netbird_test.go
Normal file
328
proxy/internal/roundtrip/netbird_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type mockMgmtClient struct{}
|
||||
|
||||
func (m *mockMgmtClient) CreateProxyPeer(_ context.Context, _ *proto.CreateProxyPeerRequest, _ ...grpc.CallOption) (*proto.CreateProxyPeerResponse, error) {
|
||||
return &proto.CreateProxyPeerResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
// mockStatusNotifier is a statusNotifier for tests that records every
// NotifyStatus call. It is safe for concurrent use.
type mockStatusNotifier struct {
	mu       sync.Mutex
	statuses []statusCall // recorded calls, in arrival order; guarded by mu
}

// statusCall captures the arguments of one NotifyStatus invocation.
type statusCall struct {
	accountID string
	serviceID string
	domain    string
	connected bool
}

// NotifyStatus appends the call to the recorded list and never fails.
func (m *mockStatusNotifier) NotifyStatus(_ context.Context, accountID, serviceID, domain string, connected bool) error {
	recorded := statusCall{
		accountID: accountID,
		serviceID: serviceID,
		domain:    domain,
		connected: connected,
	}

	m.mu.Lock()
	m.statuses = append(m.statuses, recorded)
	m.mu.Unlock()
	return nil
}

// calls returns a copy of the recorded calls, so callers can inspect or
// modify the result without affecting the mock.
func (m *mockStatusNotifier) calls() []statusCall {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]statusCall, len(m.statuses))
	copy(snapshot, m.statuses)
	return snapshot
}
|
||||
|
||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||
// It uses an invalid management URL to prevent real connections.
|
||||
func mockNetBird() *NetBird {
|
||||
return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{})
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Initially no client exists.
|
||||
assert.False(t, nb.HasClient(accountID), "should not have client before AddPeer")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
|
||||
// Add first domain - this should create a new client.
|
||||
// Note: This will fail to actually connect since we use an invalid URL,
|
||||
// but the client entry should still be created.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, nb.HasClient(accountID), "should have client after AddPeer")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ReuseClientForSameAccount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID))
|
||||
|
||||
// Add second domain for the same account - should reuse existing client.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2 after adding second domain")
|
||||
|
||||
// Add third domain.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID), "domain count should be 3 after adding third domain")
|
||||
|
||||
// Still only one client.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_SeparateClientsForDifferentAccounts(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
|
||||
// Add domain for account 1.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add domain for account 2.
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "setup-key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both accounts should have their own clients.
|
||||
assert.True(t, nb.HasClient(account1), "account1 should have client")
|
||||
assert.True(t, nb.HasClient(account2), "account2 should have client")
|
||||
assert.Equal(t, 1, nb.DomainCount(account1), "account1 domain count should be 1")
|
||||
assert.Equal(t, 1, nb.DomainCount(account2), "account2 domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_KeepsClientWhenDomainsRemain(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add multiple domains.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "setup-key-1", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain3.test"), "setup-key-1", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, nb.DomainCount(accountID))
|
||||
|
||||
// Remove one domain - client should remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing one domain")
|
||||
assert.Equal(t, 2, nb.DomainCount(accountID), "domain count should be 2")
|
||||
|
||||
// Remove another domain - client should still remain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain2.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID), "client should remain after removing second domain")
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "domain count should be 1")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_RemovesClientWhenLastDomainRemoved(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add single domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
// Remove the only domain - client should be removed.
|
||||
// Note: Stop() may fail since the client never actually connected,
|
||||
// but the entry should still be removed from the map.
|
||||
_ = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
|
||||
// After removing all domains, client should be gone.
|
||||
assert.False(t, nb.HasClient(accountID), "client should be removed after removing last domain")
|
||||
assert.Equal(t, 0, nb.DomainCount(accountID), "domain count should be 0")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentAccountIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("nonexistent-account")
|
||||
|
||||
// Removing from non-existent account should not error.
|
||||
err := nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
assert.NoError(t, err, "removing from non-existent account should not error")
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NonExistentDomainIsNoop(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add one domain.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "setup-key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove non-existent domain - should not affect existing domain.
|
||||
err = nb.RemovePeer(context.Background(), accountID, domain.Domain("nonexistent.test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Original domain should still be registered.
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
assert.Equal(t, 1, nb.DomainCount(accountID), "original domain should remain")
|
||||
}
|
||||
|
||||
func TestWithAccountID_AndAccountIDFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := types.AccountID("test-account")
|
||||
|
||||
// Initially no account ID in context.
|
||||
retrieved := AccountIDFromContext(ctx)
|
||||
assert.True(t, retrieved == "", "should be empty when not set")
|
||||
|
||||
// Add account ID to context.
|
||||
ctx = WithAccountID(ctx, accountID)
|
||||
retrieved = AccountIDFromContext(ctx)
|
||||
assert.Equal(t, accountID, retrieved, "should retrieve the same account ID")
|
||||
}
|
||||
|
||||
func TestAccountIDFromContext_ReturnsEmptyForWrongType(t *testing.T) {
|
||||
// Create context with wrong type for account ID key.
|
||||
ctx := context.WithValue(context.Background(), accountIDContextKey{}, "wrong-type-string")
|
||||
|
||||
retrieved := AccountIDFromContext(ctx)
|
||||
assert.True(t, retrieved == "", "should return empty for wrong type")
|
||||
}
|
||||
|
||||
func TestNetBird_StopAll_StopsAllClients(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
account1 := types.AccountID("account-1")
|
||||
account2 := types.AccountID("account-2")
|
||||
account3 := types.AccountID("account-3")
|
||||
|
||||
// Add domains for multiple accounts.
|
||||
err := nb.AddPeer(context.Background(), account1, domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account2, domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), account3, domain.Domain("domain3.test"), "key-3", "proxy-3")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 3, nb.ClientCount(), "should have 3 clients")
|
||||
|
||||
// Stop all clients.
|
||||
// Note: StopAll may return errors since clients never actually connected,
|
||||
// but the clients should still be removed from the map.
|
||||
_ = nb.StopAll(context.Background())
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should have 0 clients after StopAll")
|
||||
assert.False(t, nb.HasClient(account1), "account1 should not have client")
|
||||
assert.False(t, nb.HasClient(account2), "account2 should not have client")
|
||||
assert.False(t, nb.HasClient(account3), "account3 should not have client")
|
||||
}
|
||||
|
||||
func TestNetBird_ClientCount(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
|
||||
assert.Equal(t, 0, nb.ClientCount(), "should start with 0 clients")
|
||||
|
||||
// Add clients for different accounts.
|
||||
err := nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1.test"), "key-1", "proxy-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, nb.ClientCount())
|
||||
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-2"), domain.Domain("domain2.test"), "key-2", "proxy-2")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount())
|
||||
|
||||
// Adding domain to existing account should not increase count.
|
||||
err = nb.AddPeer(context.Background(), types.AccountID("account-1"), domain.Domain("domain1b.test"), "key-1", "proxy-1b")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, nb.ClientCount(), "adding domain to existing account should not increase client count")
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresAccountIDInContext(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
|
||||
// Create a request without account ID in context.
|
||||
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RoundTrip should fail because no account ID in context.
|
||||
_, err = nb.RoundTrip(req) //nolint:bodyclose
|
||||
require.ErrorIs(t, err, ErrNoAccountID)
|
||||
}
|
||||
|
||||
func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
||||
nb := mockNetBird()
|
||||
accountID := types.AccountID("nonexistent-account")
|
||||
|
||||
// Create a request with account ID but no client exists.
|
||||
req, err := http.NewRequest("GET", "http://example.com/", nil)
|
||||
require.NoError(t, err)
|
||||
req = req.WithContext(WithAccountID(req.Context(), accountID))
|
||||
|
||||
// RoundTrip should fail because no client for this account.
|
||||
_, err = nb.RoundTrip(req) //nolint:bodyclose // Error case, no response body
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no peer connection found for account")
|
||||
}
|
||||
|
||||
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
// Add first domain — creates a new client entry.
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually mark client as started to simulate background startup completing.
|
||||
nb.clientsMux.Lock()
|
||||
nb.clients[accountID].started = true
|
||||
nb.clientsMux.Unlock()
|
||||
|
||||
// Add second domain — should notify immediately since client is already started.
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, string(accountID), calls[0].accountID)
|
||||
assert.Equal(t, "svc-2", calls[0].serviceID)
|
||||
assert.Equal(t, "domain2.test", calls[0].domain)
|
||||
assert.True(t, calls[0].connected)
|
||||
}
|
||||
|
||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||
notifier := &mockStatusNotifier{}
|
||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
||||
accountID := types.AccountID("account-1")
|
||||
|
||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||
require.NoError(t, err)
|
||||
err = nb.AddPeer(context.Background(), accountID, domain.Domain("domain2.test"), "key-1", "svc-2")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove one domain — client stays, but disconnection notification fires.
|
||||
err = nb.RemovePeer(context.Background(), accountID, "domain1.test")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, nb.HasClient(accountID))
|
||||
|
||||
calls := notifier.calls()
|
||||
require.Len(t, calls, 1)
|
||||
assert.Equal(t, "domain1.test", calls[0].domain)
|
||||
assert.False(t, calls[0].connected)
|
||||
}
|
||||
152
proxy/internal/roundtrip/transport.go
Normal file
152
proxy/internal/roundtrip/transport.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package roundtrip
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Names of the environment variables that loadTransportConfig consults to
// override the defaults of the backend HTTP transport. Unset or empty
// variables leave the corresponding default untouched.
const (
	// Connection limits; parsed as non-negative integers.
	EnvMaxIdleConns        = "NB_PROXY_MAX_IDLE_CONNS"
	EnvMaxIdleConnsPerHost = "NB_PROXY_MAX_IDLE_CONNS_PER_HOST"
	EnvMaxConnsPerHost     = "NB_PROXY_MAX_CONNS_PER_HOST"

	// Timeouts; parsed as non-negative Go durations (e.g. "90s").
	EnvIdleConnTimeout       = "NB_PROXY_IDLE_CONN_TIMEOUT"
	EnvTLSHandshakeTimeout   = "NB_PROXY_TLS_HANDSHAKE_TIMEOUT"
	EnvExpectContinueTimeout = "NB_PROXY_EXPECT_CONTINUE_TIMEOUT"
	EnvResponseHeaderTimeout = "NB_PROXY_RESPONSE_HEADER_TIMEOUT"

	// Buffer sizes; parsed as non-negative integers.
	EnvWriteBufferSize = "NB_PROXY_WRITE_BUFFER_SIZE"
	EnvReadBufferSize  = "NB_PROXY_READ_BUFFER_SIZE"

	// Compression toggle; parsed as a boolean.
	EnvDisableCompression = "NB_PROXY_DISABLE_COMPRESSION"

	// Concurrent request limit; parsed as a non-negative integer.
	EnvMaxInflight = "NB_PROXY_MAX_INFLIGHT"
)
|
||||
|
||||
// transportConfig collects the tunable parameters of the per-account HTTP
// transport. Each field can be overridden through the matching NB_PROXY_*
// environment variable.
type transportConfig struct {
	// Connection limits.
	maxIdleConns        int
	maxIdleConnsPerHost int
	maxConnsPerHost     int

	// Timeouts.
	idleConnTimeout       time.Duration
	tlsHandshakeTimeout   time.Duration
	expectContinueTimeout time.Duration
	responseHeaderTimeout time.Duration

	// Buffer sizes.
	writeBufferSize int
	readBufferSize  int

	// disableCompression is the compression toggle.
	disableCompression bool

	// maxInflight limits per-backend concurrent requests. 0 means unlimited.
	maxInflight int
}
|
||||
|
||||
func defaultTransportConfig() transportConfig {
|
||||
return transportConfig{
|
||||
maxIdleConns: 100,
|
||||
maxIdleConnsPerHost: 100,
|
||||
maxConnsPerHost: 0, // unlimited
|
||||
idleConnTimeout: 90 * time.Second,
|
||||
tlsHandshakeTimeout: 10 * time.Second,
|
||||
expectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func loadTransportConfig(logger *log.Logger) transportConfig {
|
||||
cfg := defaultTransportConfig()
|
||||
|
||||
if v, ok := envInt(EnvMaxIdleConns, logger); ok {
|
||||
cfg.maxIdleConns = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxIdleConnsPerHost, logger); ok {
|
||||
cfg.maxIdleConnsPerHost = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxConnsPerHost, logger); ok {
|
||||
cfg.maxConnsPerHost = v
|
||||
}
|
||||
if v, ok := envDuration(EnvIdleConnTimeout, logger); ok {
|
||||
cfg.idleConnTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvTLSHandshakeTimeout, logger); ok {
|
||||
cfg.tlsHandshakeTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvExpectContinueTimeout, logger); ok {
|
||||
cfg.expectContinueTimeout = v
|
||||
}
|
||||
if v, ok := envDuration(EnvResponseHeaderTimeout, logger); ok {
|
||||
cfg.responseHeaderTimeout = v
|
||||
}
|
||||
if v, ok := envInt(EnvWriteBufferSize, logger); ok {
|
||||
cfg.writeBufferSize = v
|
||||
}
|
||||
if v, ok := envInt(EnvReadBufferSize, logger); ok {
|
||||
cfg.readBufferSize = v
|
||||
}
|
||||
if v, ok := envBool(EnvDisableCompression, logger); ok {
|
||||
cfg.disableCompression = v
|
||||
}
|
||||
if v, ok := envInt(EnvMaxInflight, logger); ok {
|
||||
cfg.maxInflight = v
|
||||
}
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
"max_idle_conns": cfg.maxIdleConns,
|
||||
"max_idle_conns_per_host": cfg.maxIdleConnsPerHost,
|
||||
"max_conns_per_host": cfg.maxConnsPerHost,
|
||||
"idle_conn_timeout": cfg.idleConnTimeout,
|
||||
"tls_handshake_timeout": cfg.tlsHandshakeTimeout,
|
||||
"expect_continue_timeout": cfg.expectContinueTimeout,
|
||||
"response_header_timeout": cfg.responseHeaderTimeout,
|
||||
"write_buffer_size": cfg.writeBufferSize,
|
||||
"read_buffer_size": cfg.readBufferSize,
|
||||
"disable_compression": cfg.disableCompression,
|
||||
"max_inflight": cfg.maxInflight,
|
||||
}).Debug("backend transport configuration")
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func envInt(key string, logger *log.Logger) (int, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := strconv.Atoi(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as int: %v", key, s, err)
|
||||
return 0, false
|
||||
}
|
||||
if v < 0 {
|
||||
logger.Warnf("ignoring negative value for %s=%d", key, v)
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func envDuration(key string, logger *log.Logger) (time.Duration, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return 0, false
|
||||
}
|
||||
v, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as duration: %v", key, s, err)
|
||||
return 0, false
|
||||
}
|
||||
if v < 0 {
|
||||
logger.Warnf("ignoring negative value for %s=%s", key, v)
|
||||
return 0, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
|
||||
func envBool(key string, logger *log.Logger) (bool, bool) {
|
||||
s := os.Getenv(key)
|
||||
if s == "" {
|
||||
return false, false
|
||||
}
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
logger.Warnf("failed to parse %s=%q as bool: %v", key, s, err)
|
||||
return false, false
|
||||
}
|
||||
return v, true
|
||||
}
|
||||
5
proxy/internal/types/types.go
Normal file
5
proxy/internal/types/types.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package types defines common types used across the proxy package.
|
||||
package types
|
||||
|
||||
// AccountID is the unique identifier of a NetBird account. It is a distinct
// string type so account IDs are not mixed up with other strings.
type AccountID string
|
||||
Reference in New Issue
Block a user