Add cert hot reload and cert file locking

Adds file-watching certificate hot reload, cross-replica ACME
certificate lock coordination via flock (Unix) and Kubernetes lease
objects.
This commit is contained in:
Viktor Liu
2026-02-09 20:17:05 +08:00
parent be5f30225a
commit fd442138e6
14 changed files with 1606 additions and 23 deletions

View File

@@ -45,6 +45,8 @@ var (
oidcScopes string
forwardedProto string
trustedProxies string
certFile string
certKeyFile string
)
var rootCmd = &cobra.Command{
@@ -74,6 +76,8 @@ func init() {
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
}
// Execute runs the root command.
@@ -127,6 +131,8 @@ func runServer(cmd *cobra.Command, args []string) error {
ProxyURL: proxyURL,
ProxyToken: proxyToken,
CertificateDirectory: certDir,
CertificateFile: certFile,
CertificateKeyFile: certKeyFile,
GenerateACMECertificates: acmeCerts,
ACMEChallengeAddress: acmeAddr,
ACMEDirectory: acmeDir,

View File

@@ -0,0 +1,99 @@
package acme
import (
"context"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/flock"
"github.com/netbirdio/netbird/proxy/internal/k8s"
)
// certLocker provides distributed mutual exclusion for certificate operations.
// Implementations must be safe for concurrent use from multiple goroutines.
type certLocker interface {
// Lock acquires an exclusive lock for the given domain.
// It blocks until the lock is acquired, the context is cancelled, or an
// unrecoverable error occurs. The returned function releases the lock;
// callers must call it exactly once when the critical section is complete.
Lock(ctx context.Context, domain string) (unlock func(), err error)
}
// CertLockMethod controls how ACME certificate locks are coordinated.
type CertLockMethod string
const (
// CertLockAuto detects the environment and selects k8s-lease if running
// in a Kubernetes pod, otherwise flock.
CertLockAuto CertLockMethod = "auto"
// CertLockFlock uses advisory file locks via flock(2).
CertLockFlock CertLockMethod = "flock"
// CertLockK8sLease uses Kubernetes coordination Leases.
CertLockK8sLease CertLockMethod = "k8s-lease"
)
func newCertLocker(method CertLockMethod, certDir string, logger *log.Logger) certLocker {
if logger == nil {
logger = log.StandardLogger()
}
if method == "" || method == CertLockAuto {
if k8s.InCluster() {
method = CertLockK8sLease
} else {
method = CertLockFlock
}
logger.Infof("auto-detected cert lock method: %s", method)
}
switch method {
case CertLockK8sLease:
locker, err := newK8sLeaseLocker(logger)
if err != nil {
logger.Warnf("create k8s lease locker, falling back to flock: %v", err)
return newFlockLocker(certDir, logger)
}
logger.Infof("using k8s lease locker in namespace %s", locker.client.Namespace())
return locker
default:
logger.Infof("using flock cert locker in %s", certDir)
return newFlockLocker(certDir, logger)
}
}
type flockLocker struct {
certDir string
logger *log.Logger
}
func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
return &flockLocker{certDir: certDir, logger: logger}
}
// Lock acquires an advisory file lock for the given domain.
func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
lockPath := filepath.Join(l.certDir, domain+".lock")
lockFile, err := flock.Lock(ctx, lockPath)
if err != nil {
return nil, err
}
// nil lockFile means locking is not supported (non-unix).
if lockFile == nil {
return func() {}, nil
}
return func() {
if err := flock.Unlock(lockFile); err != nil {
l.logger.Debugf("release cert lock for domain %q: %v", domain, err)
}
}, nil
}
type noopLocker struct{}
// Lock is a no-op that always succeeds immediately.
func (noopLocker) Lock(context.Context, string) (func(), error) {
return func() {}, nil
}

View File

@@ -0,0 +1,197 @@
package acme
import (
"context"
"errors"
"fmt"
"os"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/k8s"
)
const (
// leaseDurationSec is the Kubernetes Lease TTL. If the holder crashes without
// releasing the lock, other replicas must wait this long before taking over.
// This is intentionally generous: in the worst case two replicas may both
// issue an ACME request for the same domain, which is harmless (the CA
// deduplicates and the cache converges).
leaseDurationSec = 300
retryBaseBackoff = 500 * time.Millisecond
retryMaxBackoff = 10 * time.Second
)
type k8sLeaseLocker struct {
client *k8s.LeaseClient
identity string
logger *log.Logger
}
func newK8sLeaseLocker(logger *log.Logger) (*k8sLeaseLocker, error) {
client, err := k8s.NewLeaseClient()
if err != nil {
return nil, fmt.Errorf("create k8s lease client: %w", err)
}
identity, err := os.Hostname()
if err != nil {
return nil, fmt.Errorf("get hostname: %w", err)
}
return &k8sLeaseLocker{
client: client,
identity: identity,
logger: logger,
}, nil
}
// Lock acquires a Kubernetes Lease for the given domain using optimistic
// concurrency. It retries with exponential backoff until the lease is
// acquired or the context is cancelled.
func (l *k8sLeaseLocker) Lock(ctx context.Context, domain string) (func(), error) {
leaseName := k8s.LeaseNameForDomain(domain)
backoff := retryBaseBackoff
for {
acquired, err := l.tryAcquire(ctx, leaseName, domain)
if err != nil {
return nil, fmt.Errorf("acquire lease %s for %q: %w", leaseName, domain, err)
}
if acquired {
l.logger.Debugf("k8s lease %s acquired for domain %q", leaseName, domain)
return l.unlockFunc(leaseName, domain), nil
}
l.logger.Debugf("k8s lease %s held by another replica, retrying in %s", leaseName, backoff)
timer := time.NewTimer(backoff)
select {
case <-ctx.Done():
timer.Stop()
return nil, ctx.Err()
case <-timer.C:
}
backoff *= 2
if backoff > retryMaxBackoff {
backoff = retryMaxBackoff
}
}
}
// tryAcquire attempts to create or take over a Lease. Returns (true, nil)
// on success, (false, nil) if the lease is held and not stale, or an error.
func (l *k8sLeaseLocker) tryAcquire(ctx context.Context, name, domain string) (bool, error) {
existing, err := l.client.Get(ctx, name)
if err != nil {
return false, err
}
now := k8s.MicroTime{Time: time.Now().UTC()}
dur := int32(leaseDurationSec)
if existing == nil {
lease := &k8s.Lease{
Metadata: k8s.LeaseMetadata{
Name: name,
Annotations: map[string]string{
"netbird.io/domain": domain,
},
},
Spec: k8s.LeaseSpec{
HolderIdentity: &l.identity,
LeaseDurationSeconds: &dur,
AcquireTime: &now,
RenewTime: &now,
},
}
if _, err := l.client.Create(ctx, lease); errors.Is(err, k8s.ErrConflict) {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}
if !l.canTakeover(existing) {
return false, nil
}
existing.Spec.HolderIdentity = &l.identity
existing.Spec.LeaseDurationSeconds = &dur
existing.Spec.AcquireTime = &now
existing.Spec.RenewTime = &now
if _, err := l.client.Update(ctx, existing); errors.Is(err, k8s.ErrConflict) {
return false, nil
} else if err != nil {
return false, err
}
return true, nil
}
// canTakeover returns true if the lease is free (no holder) or stale
// (renewTime + leaseDuration has passed).
func (l *k8sLeaseLocker) canTakeover(lease *k8s.Lease) bool {
holder := lease.Spec.HolderIdentity
if holder == nil || *holder == "" {
return true
}
// We already hold it (e.g. from a previous crashed attempt).
if *holder == l.identity {
return true
}
if lease.Spec.RenewTime == nil || lease.Spec.LeaseDurationSeconds == nil {
return true
}
expiry := lease.Spec.RenewTime.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second)
if time.Now().After(expiry) {
l.logger.Infof("k8s lease %s held by %q is stale (expired %s ago), taking over",
lease.Metadata.Name, *holder, time.Since(expiry).Round(time.Second))
return true
}
return false
}
// unlockFunc returns a closure that releases the lease by clearing the holder.
func (l *k8sLeaseLocker) unlockFunc(name, domain string) func() {
return func() {
// Use a fresh context: the parent may already be cancelled.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// Re-GET to get current resourceVersion (ours may be stale if
// the lock was held for a long time and something updated it).
current, err := l.client.Get(ctx, name)
if err != nil {
l.logger.Debugf("release k8s lease %s for %q: get: %v", name, domain, err)
return
}
if current == nil {
return
}
// Only clear if we're still the holder.
if current.Spec.HolderIdentity == nil || *current.Spec.HolderIdentity != l.identity {
l.logger.Debugf("k8s lease %s for %q: holder changed to %v, skip release",
name, domain, current.Spec.HolderIdentity)
return
}
empty := ""
current.Spec.HolderIdentity = &empty
current.Spec.AcquireTime = nil
current.Spec.RenewTime = nil
if _, err := l.client.Update(ctx, current); err != nil {
l.logger.Debugf("release k8s lease %s for %q: update: %v", name, domain, err)
}
}
}

View File

@@ -0,0 +1,65 @@
package acme
import (
"context"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFlockLockerRoundTrip(t *testing.T) {
dir := t.TempDir()
locker := newFlockLocker(dir, nil)
unlock, err := locker.Lock(context.Background(), "example.com")
require.NoError(t, err)
require.NotNil(t, unlock)
// Lock file should exist.
assert.FileExists(t, filepath.Join(dir, "example.com.lock"))
unlock()
}
func TestNoopLocker(t *testing.T) {
locker := noopLocker{}
unlock, err := locker.Lock(context.Background(), "example.com")
require.NoError(t, err)
require.NotNil(t, unlock)
unlock()
}
func TestNewCertLockerDefaultsToFlock(t *testing.T) {
dir := t.TempDir()
// t.Setenv registers cleanup to restore the original value.
// os.Unsetenv is needed because the production code uses LookupEnv,
// which distinguishes "empty" from "not set".
t.Setenv("KUBERNETES_SERVICE_HOST", "")
os.Unsetenv("KUBERNETES_SERVICE_HOST")
locker := newCertLocker(CertLockAuto, dir, nil)
_, ok := locker.(*flockLocker)
assert.True(t, ok, "auto without k8s env should select flockLocker")
}
func TestNewCertLockerExplicitFlock(t *testing.T) {
dir := t.TempDir()
locker := newCertLocker(CertLockFlock, dir, nil)
_, ok := locker.(*flockLocker)
assert.True(t, ok, "explicit flock should select flockLocker")
}
func TestNewCertLockerK8sFallsBackToFlock(t *testing.T) {
dir := t.TempDir()
// k8s-lease without SA files should fall back to flock.
locker := newCertLocker(CertLockK8sLease, dir, nil)
_, ok := locker.(*flockLocker)
assert.True(t, ok, "k8s-lease without SA should fall back to flockLocker")
}

View File

@@ -23,9 +23,14 @@ type certificateNotifier interface {
NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error
}
// Manager wraps autocert.Manager with domain tracking and cross-replica
// coordination via a pluggable locking strategy. The locker prevents
// duplicate ACME requests when multiple replicas share a certificate cache.
type Manager struct {
*autocert.Manager
certDir string
locker certLocker
domainsMux sync.RWMutex
domains map[string]struct {
accountID string
@@ -36,11 +41,16 @@ type Manager struct {
logger *log.Logger
}
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger) *Manager {
// NewManager creates a new ACME certificate manager. The certDir is used
// for caching certificates. The lockMethod controls cross-replica
// coordination strategy (see CertLockMethod constants).
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
if logger == nil {
logger = log.StandardLogger()
}
mgr := &Manager{
certDir: certDir,
locker: newCertLocker(lockMethod, certDir, logger),
domains: make(map[string]struct {
accountID string
reverseProxyID string
@@ -59,12 +69,15 @@ func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *l
return mgr
}
func (mgr *Manager) hostPolicy(ctx context.Context, domain string) error {
func (mgr *Manager) hostPolicy(_ context.Context, host string) error {
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
mgr.domainsMux.RLock()
_, exists := mgr.domains[domain]
_, exists := mgr.domains[host]
mgr.domainsMux.RUnlock()
if !exists {
return fmt.Errorf("unknown domain %q", domain)
return fmt.Errorf("unknown domain %q", host)
}
return nil
}
@@ -84,17 +97,31 @@ func (mgr *Manager) AddDomain(domain, accountID, reverseProxyID string) {
}
// prefetchCertificate proactively triggers certificate generation for a domain.
// It acquires a distributed lock to prevent multiple replicas from issuing
// duplicate ACME requests. The second replica will block until the first
// finishes, then find the certificate in the cache.
func (mgr *Manager) prefetchCertificate(domain string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
mgr.logger.Infof("acquiring cert lock for domain %q", domain)
lockStart := time.Now()
unlock, err := mgr.locker.Lock(ctx, domain)
if err != nil {
mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", domain, err)
} else {
mgr.logger.Infof("acquired cert lock for domain %q in %s", domain, time.Since(lockStart))
defer unlock()
}
hello := &tls.ClientHelloInfo{
ServerName: domain,
Conn: &dummyConn{ctx: ctx},
}
mgr.logger.Infof("prefetching certificate for domain %q", domain)
start := time.Now()
cert, err := mgr.GetCertificate(hello)
elapsed := time.Since(start)
if err != nil {
mgr.logger.Warnf("prefetch certificate for domain %q: %v", domain, err)
return
@@ -102,11 +129,18 @@ func (mgr *Manager) prefetchCertificate(domain string) {
now := time.Now()
if cert != nil && cert.Leaf != nil {
mgr.logCertificateDetails(domain, cert.Leaf, now)
leaf := cert.Leaf
mgr.logger.Infof("certificate for domain %q ready in %s: serial=%s SANs=%v notAfter=%s",
domain, elapsed.Round(time.Millisecond),
leaf.SerialNumber.Text(16),
leaf.DNSNames,
leaf.NotAfter.UTC().Format(time.RFC3339),
)
mgr.logCertificateDetails(domain, leaf, now)
} else {
mgr.logger.Infof("certificate for domain %q ready in %s", domain, elapsed.Round(time.Millisecond))
}
mgr.logger.Infof("certificate for domain %q is ready", domain)
mgr.domainsMux.RLock()
info, exists := mgr.domains[domain]
mgr.domainsMux.RUnlock()
@@ -120,18 +154,12 @@ func (mgr *Manager) prefetchCertificate(domain string) {
// logCertificateDetails logs certificate validity and SCT timestamps.
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
mgr.logger.Infof("certificate for %q: NotBefore=%v, NotAfter=%v, now=%v",
domain, cert.NotBefore.UTC(), cert.NotAfter.UTC(), now.UTC())
if cert.NotBefore.After(now) {
mgr.logger.Warnf("certificate for %q NotBefore is in the future by %v", domain, cert.NotBefore.Sub(now))
} else {
mgr.logger.Infof("certificate for %q NotBefore is %v in the past", domain, now.Sub(cert.NotBefore))
}
sctTimestamps := mgr.parseSCTTimestamps(cert)
if len(sctTimestamps) == 0 {
mgr.logger.Warnf("certificate for %q has no embedded SCTs", domain)
return
}
@@ -140,7 +168,7 @@ func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate,
mgr.logger.Warnf("certificate for %q SCT[%d] timestamp is in the future: %v (by %v)",
domain, i, sctTime.UTC(), sctTime.Sub(now))
} else {
mgr.logger.Infof("certificate for %q SCT[%d] timestamp: %v (%v in the past)",
mgr.logger.Debugf("certificate for %q SCT[%d] timestamp: %v (%v in the past)",
domain, i, sctTime.UTC(), now.Sub(sctTime))
}
}

View File

@@ -0,0 +1,61 @@
package acme
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostPolicy(t *testing.T) {
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
mgr.AddDomain("example.com", "acc1", "rp1")
tests := []struct {
name string
host string
wantErr bool
}{
{
name: "exact domain match",
host: "example.com",
},
{
name: "domain with port",
host: "example.com:443",
},
{
name: "unknown domain",
host: "unknown.com",
wantErr: true,
},
{
name: "unknown domain with port",
host: "unknown.com:443",
wantErr: true,
},
{
name: "empty host",
host: "",
wantErr: true,
},
{
name: "port only",
host: ":443",
wantErr: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := mgr.hostPolicy(context.Background(), tc.host)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown domain")
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,279 @@
// Package certwatch watches TLS certificate files on disk and provides
// a hot-reloading GetCertificate callback for tls.Config.
package certwatch
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"path/filepath"
"sync"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
const (
defaultPollInterval = 30 * time.Second
debounceDelay = 500 * time.Millisecond
)
// Watcher monitors TLS certificate files on disk and caches the loaded
// certificate in memory. It detects changes via fsnotify (with a polling
// fallback for filesystems like NFS that lack inotify support) and
// reloads the certificate pair automatically.
type Watcher struct {
certPath string
keyPath string
mu sync.RWMutex
cert *tls.Certificate
leaf *x509.Certificate
pollInterval time.Duration
logger *log.Logger
}
// NewWatcher creates a Watcher that monitors the given cert and key files.
// It performs an initial load of the certificate and returns an error
// if the initial load fails.
func NewWatcher(certPath, keyPath string, logger *log.Logger) (*Watcher, error) {
if logger == nil {
logger = log.StandardLogger()
}
w := &Watcher{
certPath: certPath,
keyPath: keyPath,
pollInterval: defaultPollInterval,
logger: logger,
}
if err := w.reload(); err != nil {
return nil, fmt.Errorf("initial certificate load: %w", err)
}
return w, nil
}
// GetCertificate returns the current in-memory certificate.
// It is safe for concurrent use and compatible with tls.Config.GetCertificate.
func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
w.mu.RLock()
defer w.mu.RUnlock()
return w.cert, nil
}
// Watch starts watching for certificate file changes. It blocks until
// ctx is cancelled. It uses fsnotify for immediate detection and falls
// back to polling if fsnotify is unavailable (e.g. on NFS).
// Even with fsnotify active, a periodic poll runs as a safety net.
func (w *Watcher) Watch(ctx context.Context) {
// Watch the parent directory rather than individual files. Some volume
// mounts use an atomic symlink swap (..data -> timestamped dir), so
// watching the parent directory catches the link replacement.
certDir := filepath.Dir(w.certPath)
keyDir := filepath.Dir(w.keyPath)
watcher, err := fsnotify.NewWatcher()
if err != nil {
w.logger.Warnf("fsnotify unavailable, using polling only: %v", err)
w.pollLoop(ctx)
return
}
defer func() {
if err := watcher.Close(); err != nil {
w.logger.Debugf("close fsnotify watcher: %v", err)
}
}()
if err := watcher.Add(certDir); err != nil {
w.logger.Warnf("fsnotify watch on %s failed, using polling only: %v", certDir, err)
w.pollLoop(ctx)
return
}
if keyDir != certDir {
if err := watcher.Add(keyDir); err != nil {
w.logger.Warnf("fsnotify watch on %s failed: %v", keyDir, err)
}
}
w.logger.Infof("watching certificate files in %s", certDir)
w.fsnotifyLoop(ctx, watcher)
}
func (w *Watcher) fsnotifyLoop(ctx context.Context, watcher *fsnotify.Watcher) {
certBase := filepath.Base(w.certPath)
keyBase := filepath.Base(w.keyPath)
var debounce *time.Timer
defer func() {
if debounce != nil {
debounce.Stop()
}
}()
// Periodic poll as a safety net for missed fsnotify events.
pollTicker := time.NewTicker(w.pollInterval)
defer pollTicker.Stop()
for {
select {
case <-ctx.Done():
return
case event, ok := <-watcher.Events:
if !ok {
return
}
base := filepath.Base(event.Name)
if !isRelevantFile(base, certBase, keyBase) {
w.logger.Debugf("fsnotify: ignoring event %s on %s", event.Op, event.Name)
continue
}
if !event.Has(fsnotify.Create) && !event.Has(fsnotify.Write) && !event.Has(fsnotify.Rename) {
w.logger.Debugf("fsnotify: ignoring op %s on %s", event.Op, base)
continue
}
w.logger.Debugf("fsnotify: detected %s on %s, scheduling reload", event.Op, base)
// Debounce: cert-manager may write cert and key as separate
// operations. Wait briefly to load both at once.
if debounce != nil {
debounce.Stop()
}
debounce = time.AfterFunc(debounceDelay, func() {
if ctx.Err() != nil {
return
}
w.tryReload()
})
case err, ok := <-watcher.Errors:
if !ok {
return
}
w.logger.Warnf("fsnotify error: %v", err)
case <-pollTicker.C:
w.tryReload()
}
}
}
func (w *Watcher) pollLoop(ctx context.Context) {
ticker := time.NewTicker(w.pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
w.tryReload()
}
}
}
// reload loads the certificate from disk and updates the in-memory cache.
func (w *Watcher) reload() error {
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
if err != nil {
return err
}
// Parse the leaf for comparison on subsequent reloads.
if cert.Leaf == nil && len(cert.Certificate) > 0 {
leaf, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return fmt.Errorf("parse leaf certificate: %w", err)
}
cert.Leaf = leaf
}
w.mu.Lock()
w.cert = &cert
w.leaf = cert.Leaf
w.mu.Unlock()
w.logCertDetails("loaded certificate", cert.Leaf)
return nil
}
// tryReload attempts to reload the certificate. It skips the update
// if the certificate on disk is identical to the one in memory (same
// serial number and issuer) to avoid redundant log noise.
func (w *Watcher) tryReload() {
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
if err != nil {
w.logger.Warnf("reload certificate: %v", err)
return
}
if cert.Leaf == nil && len(cert.Certificate) > 0 {
leaf, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
w.logger.Warnf("parse reloaded leaf certificate: %v", err)
return
}
cert.Leaf = leaf
}
w.mu.Lock()
if w.leaf != nil && cert.Leaf != nil &&
w.leaf.SerialNumber.Cmp(cert.Leaf.SerialNumber) == 0 &&
w.leaf.Issuer.CommonName == cert.Leaf.Issuer.CommonName {
w.mu.Unlock()
return
}
prev := w.leaf
w.cert = &cert
w.leaf = cert.Leaf
w.mu.Unlock()
w.logCertChange(prev, cert.Leaf)
}
func (w *Watcher) logCertDetails(msg string, leaf *x509.Certificate) {
if leaf == nil {
w.logger.Info(msg)
return
}
w.logger.Infof("%s: subject=%q serial=%s SANs=%v notAfter=%s",
msg,
leaf.Subject.CommonName,
leaf.SerialNumber.Text(16),
leaf.DNSNames,
leaf.NotAfter.UTC().Format(time.RFC3339),
)
}
func (w *Watcher) logCertChange(prev, next *x509.Certificate) {
if prev == nil || next == nil {
w.logCertDetails("certificate reloaded from disk", next)
return
}
w.logger.Infof("certificate reloaded from disk: subject=%q -> %q serial=%s -> %s notAfter=%s -> %s",
prev.Subject.CommonName, next.Subject.CommonName,
prev.SerialNumber.Text(16), next.SerialNumber.Text(16),
prev.NotAfter.UTC().Format(time.RFC3339), next.NotAfter.UTC().Format(time.RFC3339),
)
}
// isRelevantFile returns true if the changed file name is one we care about.
// This includes the cert/key files themselves and the ..data symlink used
// by atomic volume mounts.
func isRelevantFile(changed, certBase, keyBase string) bool {
return changed == certBase || changed == keyBase || changed == "..data"
}

View File

@@ -0,0 +1,292 @@
package certwatch
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func generateSelfSignedCert(t *testing.T, serial int64) (certPEM, keyPEM []byte) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(serial),
Subject: pkix.Name{CommonName: "test"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
keyDER, err := x509.MarshalECPrivateKey(key)
require.NoError(t, err)
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return certPEM, keyPEM
}
func writeCert(t *testing.T, dir string, certPEM, keyPEM []byte) {
t.Helper()
require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.crt"), certPEM, 0o600))
require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.key"), keyPEM, 0o600))
}
func TestNewWatcher(t *testing.T) {
dir := t.TempDir()
certPEM, keyPEM := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM, keyPEM)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
cert, err := w.GetCertificate(nil)
require.NoError(t, err)
require.NotNil(t, cert)
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
}
func TestNewWatcherMissingFiles(t *testing.T) {
dir := t.TempDir()
_, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
assert.Error(t, err)
}
func TestReload(t *testing.T) {
dir := t.TempDir()
certPEM1, keyPEM1 := generateSelfSignedCert(t, 100)
writeCert(t, dir, certPEM1, keyPEM1)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
cert1, err := w.GetCertificate(nil)
require.NoError(t, err)
assert.Equal(t, int64(100), cert1.Leaf.SerialNumber.Int64())
// Write a new cert with a different serial.
certPEM2, keyPEM2 := generateSelfSignedCert(t, 200)
writeCert(t, dir, certPEM2, keyPEM2)
// Manually trigger reload.
w.tryReload()
cert2, err := w.GetCertificate(nil)
require.NoError(t, err)
assert.Equal(t, int64(200), cert2.Leaf.SerialNumber.Int64())
}
func TestTryReloadSkipsUnchanged(t *testing.T) {
dir := t.TempDir()
certPEM, keyPEM := generateSelfSignedCert(t, 42)
writeCert(t, dir, certPEM, keyPEM)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
cert1, err := w.GetCertificate(nil)
require.NoError(t, err)
// Reload with same cert - pointer should remain the same.
w.tryReload()
cert2, err := w.GetCertificate(nil)
require.NoError(t, err)
assert.Same(t, cert1, cert2, "cert pointer should not change when content is the same")
}
func TestWatchDetectsChanges(t *testing.T) {
dir := t.TempDir()
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM1, keyPEM1)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
// Use a short poll interval for the test.
w.pollInterval = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go w.Watch(ctx)
// Write new cert.
certPEM2, keyPEM2 := generateSelfSignedCert(t, 999)
writeCert(t, dir, certPEM2, keyPEM2)
// Wait for the watcher to pick it up.
require.Eventually(t, func() bool {
cert, err := w.GetCertificate(nil)
if err != nil {
return false
}
return cert.Leaf.SerialNumber.Int64() == 999
}, 5*time.Second, 50*time.Millisecond, "watcher should detect cert change")
}
func TestIsRelevantFile(t *testing.T) {
assert.True(t, isRelevantFile("tls.crt", "tls.crt", "tls.key"))
assert.True(t, isRelevantFile("tls.key", "tls.crt", "tls.key"))
assert.True(t, isRelevantFile("..data", "tls.crt", "tls.key"))
assert.False(t, isRelevantFile("other.txt", "tls.crt", "tls.key"))
}
// TestWatchSymlinkRotation simulates Kubernetes secret volume updates where
// the data directory is atomically swapped via a ..data symlink.
func TestWatchSymlinkRotation(t *testing.T) {
base := t.TempDir()
// Create initial target directory with certs.
dir1 := filepath.Join(base, "dir1")
require.NoError(t, os.Mkdir(dir1, 0o755))
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.crt"), certPEM1, 0o600))
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.key"), keyPEM1, 0o600))
// Create ..data symlink pointing to dir1.
dataLink := filepath.Join(base, "..data")
require.NoError(t, os.Symlink(dir1, dataLink))
// Create tls.crt and tls.key as symlinks to ..data/{file}.
certLink := filepath.Join(base, "tls.crt")
keyLink := filepath.Join(base, "tls.key")
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.crt"), certLink))
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.key"), keyLink))
w, err := NewWatcher(certLink, keyLink, nil)
require.NoError(t, err)
cert, err := w.GetCertificate(nil)
require.NoError(t, err)
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
w.pollInterval = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go w.Watch(ctx)
// Simulate k8s atomic rotation: create dir2, swap ..data symlink.
dir2 := filepath.Join(base, "dir2")
require.NoError(t, os.Mkdir(dir2, 0o755))
certPEM2, keyPEM2 := generateSelfSignedCert(t, 777)
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.crt"), certPEM2, 0o600))
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.key"), keyPEM2, 0o600))
// Atomic swap: create temp link, then rename over ..data.
tmpLink := filepath.Join(base, "..data_tmp")
require.NoError(t, os.Symlink(dir2, tmpLink))
require.NoError(t, os.Rename(tmpLink, dataLink))
require.Eventually(t, func() bool {
cert, err := w.GetCertificate(nil)
if err != nil {
return false
}
return cert.Leaf.SerialNumber.Int64() == 777
}, 5*time.Second, 50*time.Millisecond, "watcher should detect symlink rotation")
}
// TestPollLoopDetectsChanges verifies the poll-only fallback path works.
func TestPollLoopDetectsChanges(t *testing.T) {
dir := t.TempDir()
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM1, keyPEM1)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
w.pollInterval = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Directly use pollLoop to test the fallback path.
go w.pollLoop(ctx)
certPEM2, keyPEM2 := generateSelfSignedCert(t, 555)
writeCert(t, dir, certPEM2, keyPEM2)
require.Eventually(t, func() bool {
cert, err := w.GetCertificate(nil)
if err != nil {
return false
}
return cert.Leaf.SerialNumber.Int64() == 555
}, 5*time.Second, 50*time.Millisecond, "poll loop should detect cert change")
}
func TestGetCertificateConcurrency(t *testing.T) {
dir := t.TempDir()
certPEM, keyPEM := generateSelfSignedCert(t, 1)
writeCert(t, dir, certPEM, keyPEM)
w, err := NewWatcher(
filepath.Join(dir, "tls.crt"),
filepath.Join(dir, "tls.key"),
nil,
)
require.NoError(t, err)
// Hammer GetCertificate concurrently while reloading.
done := make(chan struct{})
go func() {
for i := 0; i < 100; i++ {
w.tryReload()
}
close(done)
}()
for i := 0; i < 1000; i++ {
cert, err := w.GetCertificate(&tls.ClientHelloInfo{})
assert.NoError(t, err)
assert.NotNil(t, cert)
}
<-done
}

View File

@@ -0,0 +1,20 @@
//go:build !unix
package flock
import (
"context"
"os"
)
// Lock is a no-op on non-Unix platforms. Returns (nil, nil) to indicate
// that no lock was acquired; callers must treat a nil file as "proceed
// without lock" rather than "lock held by someone else."
func Lock(_ context.Context, _ string) (*os.File, error) {
return nil, nil
}
// Unlock is a no-op on non-Unix platforms.
func Unlock(_ *os.File) error {
return nil
}

View File

@@ -0,0 +1,79 @@
//go:build unix
package flock
import (
"context"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLockUnlock(t *testing.T) {
lockPath := filepath.Join(t.TempDir(), "test.lock")
f, err := Lock(context.Background(), lockPath)
require.NoError(t, err)
require.NotNil(t, f)
_, err = os.Stat(lockPath)
assert.NoError(t, err, "lock file should exist")
err = Unlock(f)
assert.NoError(t, err)
}
func TestUnlockNil(t *testing.T) {
err := Unlock(nil)
assert.NoError(t, err, "unlocking nil should be a no-op")
}
func TestLockRespectsContext(t *testing.T) {
lockPath := filepath.Join(t.TempDir(), "test.lock")
f1, err := Lock(context.Background(), lockPath)
require.NoError(t, err)
defer func() { require.NoError(t, Unlock(f1)) }()
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
defer cancel()
_, err = Lock(ctx, lockPath)
assert.ErrorIs(t, err, context.DeadlineExceeded)
}
func TestLockBlocks(t *testing.T) {
lockPath := filepath.Join(t.TempDir(), "test.lock")
f1, err := Lock(context.Background(), lockPath)
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
start := time.Now()
var elapsed time.Duration
go func() {
defer wg.Done()
f2, err := Lock(context.Background(), lockPath)
elapsed = time.Since(start)
assert.NoError(t, err)
if f2 != nil {
assert.NoError(t, Unlock(f2))
}
}()
// Hold the lock for 200ms, then release.
time.Sleep(200 * time.Millisecond)
require.NoError(t, Unlock(f1))
wg.Wait()
assert.GreaterOrEqual(t, elapsed, 150*time.Millisecond,
"Lock should have blocked for at least ~200ms")
}

View File

@@ -0,0 +1,69 @@
//go:build unix
// Package flock provides best-effort advisory file locking using flock(2).
//
// This is used for cross-replica coordination (e.g. preventing duplicate
// ACME requests). Note that flock(2) does NOT work reliably on NFS volumes:
// on NFSv3 it depends on the NLM daemon, on NFSv4 Linux emulates it via
// fcntl locks with different semantics. Callers must treat lock failures
// as non-fatal and proceed without the lock.
package flock
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"time"
log "github.com/sirupsen/logrus"
)
const retryInterval = 100 * time.Millisecond
// Lock acquires an exclusive advisory lock on the given file path.
// It creates the lock file if it does not exist. The lock attempt
// respects context cancellation by using non-blocking flock with polling.
// The caller must call Unlock with the returned *os.File when done.
func Lock(ctx context.Context, path string) (*os.File, error) {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600)
if err != nil {
return nil, fmt.Errorf("open lock file %s: %w", path, err)
}
for {
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
return f, nil
} else if !errors.Is(err, syscall.EWOULDBLOCK) {
if cerr := f.Close(); cerr != nil {
log.Debugf("close lock file %s: %v", path, cerr)
}
return nil, fmt.Errorf("acquire lock on %s: %w", path, err)
}
select {
case <-ctx.Done():
if cerr := f.Close(); cerr != nil {
log.Debugf("close lock file %s: %v", path, cerr)
}
return nil, ctx.Err()
case <-time.After(retryInterval):
}
}
}
// Unlock releases the lock and closes the file.
func Unlock(f *os.File) error {
if f == nil {
return nil
}
defer f.Close()
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
return fmt.Errorf("release lock: %w", err)
}
return nil
}

282
proxy/internal/k8s/lease.go Normal file
View File

@@ -0,0 +1,282 @@
// Package k8s provides a lightweight Kubernetes API client for coordination
// Leases. It uses raw HTTP calls against the mounted service account
// credentials, avoiding a dependency on client-go.
package k8s
import (
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
)
const (
saTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token"
saNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
saCACertPath = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
leaseAPIPath = "/apis/coordination.k8s.io/v1"
)
// ErrConflict is returned when a Lease update fails due to a
// resourceVersion mismatch (another writer updated the object first).
var ErrConflict = errors.New("conflict: resource version mismatch")
// Lease represents a coordination.k8s.io/v1 Lease object with only the
// fields needed for distributed locking.
type Lease struct {
APIVersion string `json:"apiVersion"`
Kind string `json:"kind"`
Metadata LeaseMetadata `json:"metadata"`
Spec LeaseSpec `json:"spec"`
}
// LeaseMetadata holds the standard k8s object metadata fields used by Leases.
type LeaseMetadata struct {
Name string `json:"name"`
Namespace string `json:"namespace,omitempty"`
ResourceVersion string `json:"resourceVersion,omitempty"`
Annotations map[string]string `json:"annotations,omitempty"`
}
// LeaseSpec holds the Lease specification fields.
type LeaseSpec struct {
HolderIdentity *string `json:"holderIdentity"`
LeaseDurationSeconds *int32 `json:"leaseDurationSeconds,omitempty"`
AcquireTime *MicroTime `json:"acquireTime"`
RenewTime *MicroTime `json:"renewTime"`
}
// MicroTime wraps time.Time with Kubernetes MicroTime JSON formatting.
type MicroTime struct {
time.Time
}
const microTimeFormat = "2006-01-02T15:04:05.000000Z"
// MarshalJSON implements json.Marshaler with k8s MicroTime format.
func (t MicroTime) MarshalJSON() ([]byte, error) {
return json.Marshal(t.UTC().Format(microTimeFormat))
}
// UnmarshalJSON implements json.Unmarshaler with k8s MicroTime format.
func (t *MicroTime) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
if s == "" {
t.Time = time.Time{}
return nil
}
parsed, err := time.Parse(microTimeFormat, s)
if err != nil {
return fmt.Errorf("parse MicroTime %q: %w", s, err)
}
t.Time = parsed
return nil
}
// LeaseClient talks to the Kubernetes coordination API using raw HTTP.
type LeaseClient struct {
baseURL string
namespace string
httpClient *http.Client
}
// NewLeaseClient creates a client that authenticates via the pod's
// mounted service account. It reads the namespace and CA certificate
// at construction time (they don't rotate) but reads the bearer token
// fresh on each request (tokens rotate).
func NewLeaseClient() (*LeaseClient, error) {
host := os.Getenv("KUBERNETES_SERVICE_HOST")
port := os.Getenv("KUBERNETES_SERVICE_PORT")
if host == "" || port == "" {
return nil, fmt.Errorf("KUBERNETES_SERVICE_HOST/PORT not set")
}
ns, err := os.ReadFile(saNamespacePath)
if err != nil {
return nil, fmt.Errorf("read namespace from %s: %w", saNamespacePath, err)
}
caCert, err := os.ReadFile(saCACertPath)
if err != nil {
return nil, fmt.Errorf("read CA cert from %s: %w", saCACertPath, err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caCert) {
return nil, fmt.Errorf("parse CA certificate from %s", saCACertPath)
}
return &LeaseClient{
baseURL: fmt.Sprintf("https://%s:%s", host, port),
namespace: strings.TrimSpace(string(ns)),
httpClient: &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: pool,
},
},
},
}, nil
}
// Namespace returns the namespace this client operates in.
func (c *LeaseClient) Namespace() string {
return c.namespace
}
// Get retrieves a Lease by name. Returns (nil, nil) if the Lease does not exist.
func (c *LeaseClient) Get(ctx context.Context, name string) (*Lease, error) {
url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, name)
resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, c.readError(resp)
}
var lease Lease
if err := json.NewDecoder(resp.Body).Decode(&lease); err != nil {
return nil, fmt.Errorf("decode lease response: %w", err)
}
return &lease, nil
}
// Create creates a new Lease. Returns the created Lease with server-assigned
// fields like resourceVersion populated.
func (c *LeaseClient) Create(ctx context.Context, lease *Lease) (*Lease, error) {
url := fmt.Sprintf("%s%s/namespaces/%s/leases", c.baseURL, leaseAPIPath, c.namespace)
lease.APIVersion = "coordination.k8s.io/v1"
lease.Kind = "Lease"
if lease.Metadata.Namespace == "" {
lease.Metadata.Namespace = c.namespace
}
resp, err := c.doRequest(ctx, http.MethodPost, url, lease)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusConflict {
return nil, ErrConflict
}
if resp.StatusCode != http.StatusCreated {
return nil, c.readError(resp)
}
var created Lease
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
return nil, fmt.Errorf("decode created lease: %w", err)
}
return &created, nil
}
// Update replaces a Lease. The lease.Metadata.ResourceVersion must match
// the current server value (optimistic concurrency). Returns ErrConflict
// on version mismatch.
func (c *LeaseClient) Update(ctx context.Context, lease *Lease) (*Lease, error) {
url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, lease.Metadata.Name)
lease.APIVersion = "coordination.k8s.io/v1"
lease.Kind = "Lease"
resp, err := c.doRequest(ctx, http.MethodPut, url, lease)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusConflict {
return nil, ErrConflict
}
if resp.StatusCode != http.StatusOK {
return nil, c.readError(resp)
}
var updated Lease
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
return nil, fmt.Errorf("decode updated lease: %w", err)
}
return &updated, nil
}
func (c *LeaseClient) doRequest(ctx context.Context, method, url string, body any) (*http.Response, error) {
token, err := readToken()
if err != nil {
return nil, fmt.Errorf("read service account token: %w", err)
}
var bodyReader io.Reader
if body != nil {
data, err := json.Marshal(body)
if err != nil {
return nil, fmt.Errorf("marshal request body: %w", err)
}
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Accept", "application/json")
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
return c.httpClient.Do(req)
}
func readToken() (string, error) {
data, err := os.ReadFile(saTokenPath)
if err != nil {
return "", fmt.Errorf("read %s: %w", saTokenPath, err)
}
return strings.TrimSpace(string(data)), nil
}
func (c *LeaseClient) readError(resp *http.Response) error {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("k8s API %s %d: %s", resp.Request.URL.Path, resp.StatusCode, string(body))
}
// LeaseNameForDomain returns a deterministic, DNS-label-safe Lease name
// for the given domain. The domain is hashed to avoid dots and length issues.
func LeaseNameForDomain(domain string) string {
h := sha256.Sum256([]byte(domain))
return "cert-lock-" + hex.EncodeToString(h[:8])
}
// InCluster reports whether the process is running inside a Kubernetes pod
// by checking for the KUBERNETES_SERVICE_HOST environment variable.
func InCluster() bool {
_, exists := os.LookupEnv("KUBERNETES_SERVICE_HOST")
return exists
}

View File

@@ -0,0 +1,102 @@
package k8s
import (
"encoding/json"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLeaseNameForDomain(t *testing.T) {
tests := []struct {
domain string
}{
{"example.com"},
{"app.example.com"},
{"another.domain.io"},
}
seen := make(map[string]string)
for _, tc := range tests {
name := LeaseNameForDomain(tc.domain)
assert.True(t, len(name) <= 63, "must be valid DNS label length")
assert.Regexp(t, `^cert-lock-[0-9a-f]{16}$`, name,
"must match expected format for domain %q", tc.domain)
// Same input produces same output.
assert.Equal(t, name, LeaseNameForDomain(tc.domain), "must be deterministic")
// Different domains produce different names.
if prev, ok := seen[name]; ok {
t.Errorf("collision: %q and %q both map to %s", prev, tc.domain, name)
}
seen[name] = tc.domain
}
}
func TestMicroTimeJSON(t *testing.T) {
ts := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
mt := &MicroTime{Time: ts}
data, err := json.Marshal(mt)
require.NoError(t, err)
assert.Equal(t, `"2024-06-15T10:30:00.000000Z"`, string(data))
var decoded MicroTime
require.NoError(t, json.Unmarshal(data, &decoded))
assert.True(t, ts.Equal(decoded.Time), "round-trip should preserve time")
}
func TestMicroTimeNullJSON(t *testing.T) {
// Null pointer serializes as JSON null via the Lease struct.
spec := LeaseSpec{
HolderIdentity: nil,
AcquireTime: nil,
RenewTime: nil,
}
data, err := json.Marshal(spec)
require.NoError(t, err)
assert.Contains(t, string(data), `"acquireTime":null`)
assert.Contains(t, string(data), `"renewTime":null`)
}
func TestLeaseJSONRoundTrip(t *testing.T) {
holder := "pod-abc"
dur := int32(300)
now := MicroTime{Time: time.Now().UTC().Truncate(time.Microsecond)}
original := Lease{
APIVersion: "coordination.k8s.io/v1",
Kind: "Lease",
Metadata: LeaseMetadata{
Name: "cert-lock-abcdef0123456789",
Namespace: "default",
ResourceVersion: "12345",
Annotations: map[string]string{
"netbird.io/domain": "app.example.com",
},
},
Spec: LeaseSpec{
HolderIdentity: &holder,
LeaseDurationSeconds: &dur,
AcquireTime: &now,
RenewTime: &now,
},
}
data, err := json.Marshal(original)
require.NoError(t, err)
var decoded Lease
require.NoError(t, json.Unmarshal(data, &decoded))
assert.Equal(t, original.Metadata.Name, decoded.Metadata.Name)
assert.Equal(t, original.Metadata.ResourceVersion, decoded.Metadata.ResourceVersion)
assert.Equal(t, *original.Spec.HolderIdentity, *decoded.Spec.HolderIdentity)
assert.Equal(t, *original.Spec.LeaseDurationSeconds, *decoded.Spec.LeaseDurationSeconds)
assert.True(t, original.Spec.AcquireTime.Equal(decoded.Spec.AcquireTime.Time))
}

View File

@@ -32,6 +32,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/accesslog"
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
@@ -65,6 +66,8 @@ type Server struct {
ProxyURL string
ManagementAddress string
CertificateDirectory string
CertificateFile string
CertificateKeyFile string
GenerateACMECertificates bool
ACMEChallengeAddress string
ACMEDirectory string
@@ -210,16 +213,17 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
"ServerName": s.ProxyURL,
}).Debug("started ACME challenge server")
} else {
s.Logger.Debug("ACME certificates disabled, using static certificates")
// Otherwise pull some certificates from expected locations.
cert, err := tls.LoadX509KeyPair(
filepath.Join(s.CertificateDirectory, "tls.crt"),
filepath.Join(s.CertificateDirectory, "tls.key"),
)
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
if err != nil {
return fmt.Errorf("load provided certificate: %w", err)
return fmt.Errorf("initialize certificate watcher: %w", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
go certWatcher.Watch(ctx)
tlsConfig.GetCertificate = certWatcher.GetCertificate
}
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.