From fd442138e6e63444235592d5eed3e66994983f19 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 9 Feb 2026 20:17:05 +0800 Subject: [PATCH] Add cert hot reload and cert file locking Adds file-watching certificate hot reload, cross-replica ACME certificate lock coordination via flock (Unix) and Kubernetes lease objects. --- proxy/cmd/proxy/cmd/root.go | 6 + proxy/internal/acme/locker.go | 99 ++++++++ proxy/internal/acme/locker_k8s.go | 197 +++++++++++++++ proxy/internal/acme/locker_test.go | 65 +++++ proxy/internal/acme/manager.go | 58 +++-- proxy/internal/acme/manager_test.go | 61 +++++ proxy/internal/certwatch/watcher.go | 279 ++++++++++++++++++++++ proxy/internal/certwatch/watcher_test.go | 292 +++++++++++++++++++++++ proxy/internal/flock/flock_other.go | 20 ++ proxy/internal/flock/flock_test.go | 79 ++++++ proxy/internal/flock/flock_unix.go | 69 ++++++ proxy/internal/k8s/lease.go | 282 ++++++++++++++++++++++ proxy/internal/k8s/lease_test.go | 102 ++++++++ proxy/server.go | 20 +- 14 files changed, 1606 insertions(+), 23 deletions(-) create mode 100644 proxy/internal/acme/locker.go create mode 100644 proxy/internal/acme/locker_k8s.go create mode 100644 proxy/internal/acme/locker_test.go create mode 100644 proxy/internal/acme/manager_test.go create mode 100644 proxy/internal/certwatch/watcher.go create mode 100644 proxy/internal/certwatch/watcher_test.go create mode 100644 proxy/internal/flock/flock_other.go create mode 100644 proxy/internal/flock/flock_test.go create mode 100644 proxy/internal/flock/flock_unix.go create mode 100644 proxy/internal/k8s/lease.go create mode 100644 proxy/internal/k8s/lease_test.go diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 7c0cfb0e3..450927fd8 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -45,6 +45,8 @@ var ( oidcScopes string forwardedProto string trustedProxies string + certFile string + certKeyFile string ) var rootCmd = &cobra.Command{ @@ -74,6 +76,8 @@ func 
init() { rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated") rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https") rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')") + rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory") + rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory") } // Execute runs the root command. @@ -127,6 +131,8 @@ func runServer(cmd *cobra.Command, args []string) error { ProxyURL: proxyURL, ProxyToken: proxyToken, CertificateDirectory: certDir, + CertificateFile: certFile, + CertificateKeyFile: certKeyFile, GenerateACMECertificates: acmeCerts, ACMEChallengeAddress: acmeAddr, ACMEDirectory: acmeDir, diff --git a/proxy/internal/acme/locker.go b/proxy/internal/acme/locker.go new file mode 100644 index 000000000..11c3abdfe --- /dev/null +++ b/proxy/internal/acme/locker.go @@ -0,0 +1,99 @@ +package acme + +import ( + "context" + "path/filepath" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/flock" + "github.com/netbirdio/netbird/proxy/internal/k8s" +) + +// certLocker provides distributed mutual exclusion for certificate operations. +// Implementations must be safe for concurrent use from multiple goroutines. +type certLocker interface { + // Lock acquires an exclusive lock for the given domain. 
+ // It blocks until the lock is acquired, the context is cancelled, or an + // unrecoverable error occurs. The returned function releases the lock; + // callers must call it exactly once when the critical section is complete. + Lock(ctx context.Context, domain string) (unlock func(), err error) +} + +// CertLockMethod controls how ACME certificate locks are coordinated. +type CertLockMethod string + +const ( + // CertLockAuto detects the environment and selects k8s-lease if running + // in a Kubernetes pod, otherwise flock. + CertLockAuto CertLockMethod = "auto" + // CertLockFlock uses advisory file locks via flock(2). + CertLockFlock CertLockMethod = "flock" + // CertLockK8sLease uses Kubernetes coordination Leases. + CertLockK8sLease CertLockMethod = "k8s-lease" +) + +func newCertLocker(method CertLockMethod, certDir string, logger *log.Logger) certLocker { + if logger == nil { + logger = log.StandardLogger() + } + + if method == "" || method == CertLockAuto { + if k8s.InCluster() { + method = CertLockK8sLease + } else { + method = CertLockFlock + } + logger.Infof("auto-detected cert lock method: %s", method) + } + + switch method { + case CertLockK8sLease: + locker, err := newK8sLeaseLocker(logger) + if err != nil { + logger.Warnf("create k8s lease locker, falling back to flock: %v", err) + return newFlockLocker(certDir, logger) + } + logger.Infof("using k8s lease locker in namespace %s", locker.client.Namespace()) + return locker + default: + logger.Infof("using flock cert locker in %s", certDir) + return newFlockLocker(certDir, logger) + } +} + +type flockLocker struct { + certDir string + logger *log.Logger +} + +func newFlockLocker(certDir string, logger *log.Logger) *flockLocker { + return &flockLocker{certDir: certDir, logger: logger} +} + +// Lock acquires an advisory file lock for the given domain. 
+func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) { + lockPath := filepath.Join(l.certDir, domain+".lock") + lockFile, err := flock.Lock(ctx, lockPath) + if err != nil { + return nil, err + } + + // nil lockFile means locking is not supported (non-unix). + if lockFile == nil { + return func() {}, nil + } + + return func() { + if err := flock.Unlock(lockFile); err != nil { + l.logger.Debugf("release cert lock for domain %q: %v", domain, err) + } + }, nil +} + +type noopLocker struct{} + +// Lock is a no-op that always succeeds immediately. +func (noopLocker) Lock(context.Context, string) (func(), error) { + return func() {}, nil +} diff --git a/proxy/internal/acme/locker_k8s.go b/proxy/internal/acme/locker_k8s.go new file mode 100644 index 000000000..a3f8043e6 --- /dev/null +++ b/proxy/internal/acme/locker_k8s.go @@ -0,0 +1,197 @@ +package acme + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/proxy/internal/k8s" +) + +const ( + // leaseDurationSec is the Kubernetes Lease TTL. If the holder crashes without + // releasing the lock, other replicas must wait this long before taking over. + // This is intentionally generous: in the worst case two replicas may both + // issue an ACME request for the same domain, which is harmless (the CA + // deduplicates and the cache converges). 
+ leaseDurationSec = 300 + retryBaseBackoff = 500 * time.Millisecond + retryMaxBackoff = 10 * time.Second +) + +type k8sLeaseLocker struct { + client *k8s.LeaseClient + identity string + logger *log.Logger +} + +func newK8sLeaseLocker(logger *log.Logger) (*k8sLeaseLocker, error) { + client, err := k8s.NewLeaseClient() + if err != nil { + return nil, fmt.Errorf("create k8s lease client: %w", err) + } + + identity, err := os.Hostname() + if err != nil { + return nil, fmt.Errorf("get hostname: %w", err) + } + + return &k8sLeaseLocker{ + client: client, + identity: identity, + logger: logger, + }, nil +} + +// Lock acquires a Kubernetes Lease for the given domain using optimistic +// concurrency. It retries with exponential backoff until the lease is +// acquired or the context is cancelled. +func (l *k8sLeaseLocker) Lock(ctx context.Context, domain string) (func(), error) { + leaseName := k8s.LeaseNameForDomain(domain) + backoff := retryBaseBackoff + + for { + acquired, err := l.tryAcquire(ctx, leaseName, domain) + if err != nil { + return nil, fmt.Errorf("acquire lease %s for %q: %w", leaseName, domain, err) + } + if acquired { + l.logger.Debugf("k8s lease %s acquired for domain %q", leaseName, domain) + return l.unlockFunc(leaseName, domain), nil + } + + l.logger.Debugf("k8s lease %s held by another replica, retrying in %s", leaseName, backoff) + + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + + backoff *= 2 + if backoff > retryMaxBackoff { + backoff = retryMaxBackoff + } + } +} + +// tryAcquire attempts to create or take over a Lease. Returns (true, nil) +// on success, (false, nil) if the lease is held and not stale, or an error. 
+func (l *k8sLeaseLocker) tryAcquire(ctx context.Context, name, domain string) (bool, error) { + existing, err := l.client.Get(ctx, name) + if err != nil { + return false, err + } + + now := k8s.MicroTime{Time: time.Now().UTC()} + dur := int32(leaseDurationSec) + + if existing == nil { + lease := &k8s.Lease{ + Metadata: k8s.LeaseMetadata{ + Name: name, + Annotations: map[string]string{ + "netbird.io/domain": domain, + }, + }, + Spec: k8s.LeaseSpec{ + HolderIdentity: &l.identity, + LeaseDurationSeconds: &dur, + AcquireTime: &now, + RenewTime: &now, + }, + } + + if _, err := l.client.Create(ctx, lease); errors.Is(err, k8s.ErrConflict) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil + } + + if !l.canTakeover(existing) { + return false, nil + } + + existing.Spec.HolderIdentity = &l.identity + existing.Spec.LeaseDurationSeconds = &dur + existing.Spec.AcquireTime = &now + existing.Spec.RenewTime = &now + + if _, err := l.client.Update(ctx, existing); errors.Is(err, k8s.ErrConflict) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + +// canTakeover returns true if the lease is free (no holder) or stale +// (renewTime + leaseDuration has passed). +func (l *k8sLeaseLocker) canTakeover(lease *k8s.Lease) bool { + holder := lease.Spec.HolderIdentity + if holder == nil || *holder == "" { + return true + } + + // We already hold it (e.g. from a previous crashed attempt). 
+ if *holder == l.identity { + return true + } + + if lease.Spec.RenewTime == nil || lease.Spec.LeaseDurationSeconds == nil { + return true + } + + expiry := lease.Spec.RenewTime.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second) + if time.Now().After(expiry) { + l.logger.Infof("k8s lease %s held by %q is stale (expired %s ago), taking over", + lease.Metadata.Name, *holder, time.Since(expiry).Round(time.Second)) + return true + } + + return false +} + +// unlockFunc returns a closure that releases the lease by clearing the holder. +func (l *k8sLeaseLocker) unlockFunc(name, domain string) func() { + return func() { + // Use a fresh context: the parent may already be cancelled. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Re-GET to get current resourceVersion (ours may be stale if + // the lock was held for a long time and something updated it). + current, err := l.client.Get(ctx, name) + if err != nil { + l.logger.Debugf("release k8s lease %s for %q: get: %v", name, domain, err) + return + } + if current == nil { + return + } + + // Only clear if we're still the holder. 
+ if current.Spec.HolderIdentity == nil || *current.Spec.HolderIdentity != l.identity { + l.logger.Debugf("k8s lease %s for %q: holder changed to %v, skip release", + name, domain, current.Spec.HolderIdentity) + return + } + + empty := "" + current.Spec.HolderIdentity = &empty + current.Spec.AcquireTime = nil + current.Spec.RenewTime = nil + + if _, err := l.client.Update(ctx, current); err != nil { + l.logger.Debugf("release k8s lease %s for %q: update: %v", name, domain, err) + } + } +} diff --git a/proxy/internal/acme/locker_test.go b/proxy/internal/acme/locker_test.go new file mode 100644 index 000000000..39245df0c --- /dev/null +++ b/proxy/internal/acme/locker_test.go @@ -0,0 +1,65 @@ +package acme + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFlockLockerRoundTrip(t *testing.T) { + dir := t.TempDir() + locker := newFlockLocker(dir, nil) + + unlock, err := locker.Lock(context.Background(), "example.com") + require.NoError(t, err) + require.NotNil(t, unlock) + + // Lock file should exist. + assert.FileExists(t, filepath.Join(dir, "example.com.lock")) + + unlock() +} + +func TestNoopLocker(t *testing.T) { + locker := noopLocker{} + unlock, err := locker.Lock(context.Background(), "example.com") + require.NoError(t, err) + require.NotNil(t, unlock) + unlock() +} + +func TestNewCertLockerDefaultsToFlock(t *testing.T) { + dir := t.TempDir() + + // t.Setenv registers cleanup to restore the original value. + // os.Unsetenv is needed because the production code uses LookupEnv, + // which distinguishes "empty" from "not set". 
+ t.Setenv("KUBERNETES_SERVICE_HOST", "") + os.Unsetenv("KUBERNETES_SERVICE_HOST") + locker := newCertLocker(CertLockAuto, dir, nil) + + _, ok := locker.(*flockLocker) + assert.True(t, ok, "auto without k8s env should select flockLocker") +} + +func TestNewCertLockerExplicitFlock(t *testing.T) { + dir := t.TempDir() + locker := newCertLocker(CertLockFlock, dir, nil) + + _, ok := locker.(*flockLocker) + assert.True(t, ok, "explicit flock should select flockLocker") +} + +func TestNewCertLockerK8sFallsBackToFlock(t *testing.T) { + dir := t.TempDir() + + // k8s-lease without SA files should fall back to flock. + locker := newCertLocker(CertLockK8sLease, dir, nil) + + _, ok := locker.(*flockLocker) + assert.True(t, ok, "k8s-lease without SA should fall back to flockLocker") +} diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index d48306833..e3e2bb5fc 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -23,9 +23,14 @@ type certificateNotifier interface { NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error } +// Manager wraps autocert.Manager with domain tracking and cross-replica +// coordination via a pluggable locking strategy. The locker prevents +// duplicate ACME requests when multiple replicas share a certificate cache. type Manager struct { *autocert.Manager + certDir string + locker certLocker domainsMux sync.RWMutex domains map[string]struct { accountID string @@ -36,11 +41,16 @@ type Manager struct { logger *log.Logger } -func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger) *Manager { +// NewManager creates a new ACME certificate manager. The certDir is used +// for caching certificates. The lockMethod controls cross-replica +// coordination strategy (see CertLockMethod constants). 
+func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { if logger == nil { logger = log.StandardLogger() } mgr := &Manager{ + certDir: certDir, + locker: newCertLocker(lockMethod, certDir, logger), domains: make(map[string]struct { accountID string reverseProxyID string @@ -59,12 +69,15 @@ func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *l return mgr } -func (mgr *Manager) hostPolicy(ctx context.Context, domain string) error { +func (mgr *Manager) hostPolicy(_ context.Context, host string) error { + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } mgr.domainsMux.RLock() - _, exists := mgr.domains[domain] + _, exists := mgr.domains[host] mgr.domainsMux.RUnlock() if !exists { - return fmt.Errorf("unknown domain %q", domain) + return fmt.Errorf("unknown domain %q", host) } return nil } @@ -84,17 +97,31 @@ func (mgr *Manager) AddDomain(domain, accountID, reverseProxyID string) { } // prefetchCertificate proactively triggers certificate generation for a domain. +// It acquires a distributed lock to prevent multiple replicas from issuing +// duplicate ACME requests. The second replica will block until the first +// finishes, then find the certificate in the cache. 
func (mgr *Manager) prefetchCertificate(domain string) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) defer cancel() + mgr.logger.Infof("acquiring cert lock for domain %q", domain) + lockStart := time.Now() + unlock, err := mgr.locker.Lock(ctx, domain) + if err != nil { + mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", domain, err) + } else { + mgr.logger.Infof("acquired cert lock for domain %q in %s", domain, time.Since(lockStart)) + defer unlock() + } + hello := &tls.ClientHelloInfo{ ServerName: domain, Conn: &dummyConn{ctx: ctx}, } - mgr.logger.Infof("prefetching certificate for domain %q", domain) + start := time.Now() cert, err := mgr.GetCertificate(hello) + elapsed := time.Since(start) if err != nil { mgr.logger.Warnf("prefetch certificate for domain %q: %v", domain, err) return @@ -102,11 +129,18 @@ func (mgr *Manager) prefetchCertificate(domain string) { now := time.Now() if cert != nil && cert.Leaf != nil { - mgr.logCertificateDetails(domain, cert.Leaf, now) + leaf := cert.Leaf + mgr.logger.Infof("certificate for domain %q ready in %s: serial=%s SANs=%v notAfter=%s", + domain, elapsed.Round(time.Millisecond), + leaf.SerialNumber.Text(16), + leaf.DNSNames, + leaf.NotAfter.UTC().Format(time.RFC3339), + ) + mgr.logCertificateDetails(domain, leaf, now) + } else { + mgr.logger.Infof("certificate for domain %q ready in %s", domain, elapsed.Round(time.Millisecond)) } - mgr.logger.Infof("certificate for domain %q is ready", domain) - mgr.domainsMux.RLock() info, exists := mgr.domains[domain] mgr.domainsMux.RUnlock() @@ -120,18 +154,12 @@ func (mgr *Manager) prefetchCertificate(domain string) { // logCertificateDetails logs certificate validity and SCT timestamps. 
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) { - mgr.logger.Infof("certificate for %q: NotBefore=%v, NotAfter=%v, now=%v", - domain, cert.NotBefore.UTC(), cert.NotAfter.UTC(), now.UTC()) - if cert.NotBefore.After(now) { mgr.logger.Warnf("certificate for %q NotBefore is in the future by %v", domain, cert.NotBefore.Sub(now)) - } else { - mgr.logger.Infof("certificate for %q NotBefore is %v in the past", domain, now.Sub(cert.NotBefore)) } sctTimestamps := mgr.parseSCTTimestamps(cert) if len(sctTimestamps) == 0 { - mgr.logger.Warnf("certificate for %q has no embedded SCTs", domain) return } @@ -140,7 +168,7 @@ func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, mgr.logger.Warnf("certificate for %q SCT[%d] timestamp is in the future: %v (by %v)", domain, i, sctTime.UTC(), sctTime.Sub(now)) } else { - mgr.logger.Infof("certificate for %q SCT[%d] timestamp: %v (%v in the past)", + mgr.logger.Debugf("certificate for %q SCT[%d] timestamp: %v (%v in the past)", domain, i, sctTime.UTC(), now.Sub(sctTime)) } } diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go new file mode 100644 index 000000000..273d5d95a --- /dev/null +++ b/proxy/internal/acme/manager_test.go @@ -0,0 +1,61 @@ +package acme + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHostPolicy(t *testing.T) { + mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") + mgr.AddDomain("example.com", "acc1", "rp1") + + tests := []struct { + name string + host string + wantErr bool + }{ + { + name: "exact domain match", + host: "example.com", + }, + { + name: "domain with port", + host: "example.com:443", + }, + { + name: "unknown domain", + host: "unknown.com", + wantErr: true, + }, + { + name: "unknown domain with port", + host: "unknown.com:443", + wantErr: true, + }, + { + name: "empty host", + 
host: "", + wantErr: true, + }, + { + name: "port only", + host: ":443", + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := mgr.hostPolicy(context.Background(), tc.host) + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown domain") + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/proxy/internal/certwatch/watcher.go b/proxy/internal/certwatch/watcher.go new file mode 100644 index 000000000..78ad1ab7c --- /dev/null +++ b/proxy/internal/certwatch/watcher.go @@ -0,0 +1,279 @@ +// Package certwatch watches TLS certificate files on disk and provides +// a hot-reloading GetCertificate callback for tls.Config. +package certwatch + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "path/filepath" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + log "github.com/sirupsen/logrus" +) + +const ( + defaultPollInterval = 30 * time.Second + debounceDelay = 500 * time.Millisecond +) + +// Watcher monitors TLS certificate files on disk and caches the loaded +// certificate in memory. It detects changes via fsnotify (with a polling +// fallback for filesystems like NFS that lack inotify support) and +// reloads the certificate pair automatically. +type Watcher struct { + certPath string + keyPath string + + mu sync.RWMutex + cert *tls.Certificate + leaf *x509.Certificate + + pollInterval time.Duration + logger *log.Logger +} + +// NewWatcher creates a Watcher that monitors the given cert and key files. +// It performs an initial load of the certificate and returns an error +// if the initial load fails. 
+func NewWatcher(certPath, keyPath string, logger *log.Logger) (*Watcher, error) { + if logger == nil { + logger = log.StandardLogger() + } + + w := &Watcher{ + certPath: certPath, + keyPath: keyPath, + pollInterval: defaultPollInterval, + logger: logger, + } + + if err := w.reload(); err != nil { + return nil, fmt.Errorf("initial certificate load: %w", err) + } + + return w, nil +} + +// GetCertificate returns the current in-memory certificate. +// It is safe for concurrent use and compatible with tls.Config.GetCertificate. +func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + w.mu.RLock() + defer w.mu.RUnlock() + + return w.cert, nil +} + +// Watch starts watching for certificate file changes. It blocks until +// ctx is cancelled. It uses fsnotify for immediate detection and falls +// back to polling if fsnotify is unavailable (e.g. on NFS). +// Even with fsnotify active, a periodic poll runs as a safety net. +func (w *Watcher) Watch(ctx context.Context) { + // Watch the parent directory rather than individual files. Some volume + // mounts use an atomic symlink swap (..data -> timestamped dir), so + // watching the parent directory catches the link replacement. 
+ certDir := filepath.Dir(w.certPath) + keyDir := filepath.Dir(w.keyPath) + + watcher, err := fsnotify.NewWatcher() + if err != nil { + w.logger.Warnf("fsnotify unavailable, using polling only: %v", err) + w.pollLoop(ctx) + return + } + defer func() { + if err := watcher.Close(); err != nil { + w.logger.Debugf("close fsnotify watcher: %v", err) + } + }() + + if err := watcher.Add(certDir); err != nil { + w.logger.Warnf("fsnotify watch on %s failed, using polling only: %v", certDir, err) + w.pollLoop(ctx) + return + } + + if keyDir != certDir { + if err := watcher.Add(keyDir); err != nil { + w.logger.Warnf("fsnotify watch on %s failed: %v", keyDir, err) + } + } + + w.logger.Infof("watching certificate files in %s", certDir) + w.fsnotifyLoop(ctx, watcher) +} + +func (w *Watcher) fsnotifyLoop(ctx context.Context, watcher *fsnotify.Watcher) { + certBase := filepath.Base(w.certPath) + keyBase := filepath.Base(w.keyPath) + + var debounce *time.Timer + defer func() { + if debounce != nil { + debounce.Stop() + } + }() + + // Periodic poll as a safety net for missed fsnotify events. + pollTicker := time.NewTicker(w.pollInterval) + defer pollTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + + case event, ok := <-watcher.Events: + if !ok { + return + } + + base := filepath.Base(event.Name) + if !isRelevantFile(base, certBase, keyBase) { + w.logger.Debugf("fsnotify: ignoring event %s on %s", event.Op, event.Name) + continue + } + if !event.Has(fsnotify.Create) && !event.Has(fsnotify.Write) && !event.Has(fsnotify.Rename) { + w.logger.Debugf("fsnotify: ignoring op %s on %s", event.Op, base) + continue + } + + w.logger.Debugf("fsnotify: detected %s on %s, scheduling reload", event.Op, base) + + // Debounce: cert-manager may write cert and key as separate + // operations. Wait briefly to load both at once. 
+ if debounce != nil { + debounce.Stop() + } + debounce = time.AfterFunc(debounceDelay, func() { + if ctx.Err() != nil { + return + } + w.tryReload() + }) + + case err, ok := <-watcher.Errors: + if !ok { + return + } + w.logger.Warnf("fsnotify error: %v", err) + + case <-pollTicker.C: + w.tryReload() + } + } +} + +func (w *Watcher) pollLoop(ctx context.Context) { + ticker := time.NewTicker(w.pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + w.tryReload() + } + } +} + +// reload loads the certificate from disk and updates the in-memory cache. +func (w *Watcher) reload() error { + cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath) + if err != nil { + return err + } + + // Parse the leaf for comparison on subsequent reloads. + if cert.Leaf == nil && len(cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return fmt.Errorf("parse leaf certificate: %w", err) + } + cert.Leaf = leaf + } + + w.mu.Lock() + w.cert = &cert + w.leaf = cert.Leaf + w.mu.Unlock() + + w.logCertDetails("loaded certificate", cert.Leaf) + + return nil +} + +// tryReload attempts to reload the certificate. It skips the update +// if the certificate on disk is identical to the one in memory (same +// serial number and issuer) to avoid redundant log noise. 
+func (w *Watcher) tryReload() { + cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath) + if err != nil { + w.logger.Warnf("reload certificate: %v", err) + return + } + + if cert.Leaf == nil && len(cert.Certificate) > 0 { + leaf, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + w.logger.Warnf("parse reloaded leaf certificate: %v", err) + return + } + cert.Leaf = leaf + } + + w.mu.Lock() + + if w.leaf != nil && cert.Leaf != nil && + w.leaf.SerialNumber.Cmp(cert.Leaf.SerialNumber) == 0 && + w.leaf.Issuer.CommonName == cert.Leaf.Issuer.CommonName { + w.mu.Unlock() + return + } + + prev := w.leaf + w.cert = &cert + w.leaf = cert.Leaf + w.mu.Unlock() + + w.logCertChange(prev, cert.Leaf) +} + +func (w *Watcher) logCertDetails(msg string, leaf *x509.Certificate) { + if leaf == nil { + w.logger.Info(msg) + return + } + + w.logger.Infof("%s: subject=%q serial=%s SANs=%v notAfter=%s", + msg, + leaf.Subject.CommonName, + leaf.SerialNumber.Text(16), + leaf.DNSNames, + leaf.NotAfter.UTC().Format(time.RFC3339), + ) +} + +func (w *Watcher) logCertChange(prev, next *x509.Certificate) { + if prev == nil || next == nil { + w.logCertDetails("certificate reloaded from disk", next) + return + } + + w.logger.Infof("certificate reloaded from disk: subject=%q -> %q serial=%s -> %s notAfter=%s -> %s", + prev.Subject.CommonName, next.Subject.CommonName, + prev.SerialNumber.Text(16), next.SerialNumber.Text(16), + prev.NotAfter.UTC().Format(time.RFC3339), next.NotAfter.UTC().Format(time.RFC3339), + ) +} + +// isRelevantFile returns true if the changed file name is one we care about. +// This includes the cert/key files themselves and the ..data symlink used +// by atomic volume mounts. 
+func isRelevantFile(changed, certBase, keyBase string) bool { + return changed == certBase || changed == keyBase || changed == "..data" +} diff --git a/proxy/internal/certwatch/watcher_test.go b/proxy/internal/certwatch/watcher_test.go new file mode 100644 index 000000000..06b0a4bb8 --- /dev/null +++ b/proxy/internal/certwatch/watcher_test.go @@ -0,0 +1,292 @@ +package certwatch + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateSelfSignedCert(t *testing.T, serial int64) (certPEM, keyPEM []byte) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(serial), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(24 * time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + keyDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return certPEM, keyPEM +} + +func writeCert(t *testing.T, dir string, certPEM, keyPEM []byte) { + t.Helper() + + require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.crt"), certPEM, 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.key"), keyPEM, 0o600)) +} + +func TestNewWatcher(t *testing.T) { + dir := t.TempDir() + certPEM, keyPEM := generateSelfSignedCert(t, 1) + writeCert(t, dir, certPEM, keyPEM) + + w, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + require.NoError(t, err) + + cert, err := 
w.GetCertificate(nil) + require.NoError(t, err) + require.NotNil(t, cert) + assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64()) +} + +func TestNewWatcherMissingFiles(t *testing.T) { + dir := t.TempDir() + + _, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + assert.Error(t, err) +} + +func TestReload(t *testing.T) { + dir := t.TempDir() + certPEM1, keyPEM1 := generateSelfSignedCert(t, 100) + writeCert(t, dir, certPEM1, keyPEM1) + + w, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + require.NoError(t, err) + + cert1, err := w.GetCertificate(nil) + require.NoError(t, err) + assert.Equal(t, int64(100), cert1.Leaf.SerialNumber.Int64()) + + // Write a new cert with a different serial. + certPEM2, keyPEM2 := generateSelfSignedCert(t, 200) + writeCert(t, dir, certPEM2, keyPEM2) + + // Manually trigger reload. + w.tryReload() + + cert2, err := w.GetCertificate(nil) + require.NoError(t, err) + assert.Equal(t, int64(200), cert2.Leaf.SerialNumber.Int64()) +} + +func TestTryReloadSkipsUnchanged(t *testing.T) { + dir := t.TempDir() + certPEM, keyPEM := generateSelfSignedCert(t, 42) + writeCert(t, dir, certPEM, keyPEM) + + w, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + require.NoError(t, err) + + cert1, err := w.GetCertificate(nil) + require.NoError(t, err) + + // Reload with same cert - pointer should remain the same. + w.tryReload() + + cert2, err := w.GetCertificate(nil) + require.NoError(t, err) + assert.Same(t, cert1, cert2, "cert pointer should not change when content is the same") +} + +func TestWatchDetectsChanges(t *testing.T) { + dir := t.TempDir() + certPEM1, keyPEM1 := generateSelfSignedCert(t, 1) + writeCert(t, dir, certPEM1, keyPEM1) + + w, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + require.NoError(t, err) + + // Use a short poll interval for the test. 
+ w.pollInterval = 100 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go w.Watch(ctx) + + // Write new cert. + certPEM2, keyPEM2 := generateSelfSignedCert(t, 999) + writeCert(t, dir, certPEM2, keyPEM2) + + // Wait for the watcher to pick it up. + require.Eventually(t, func() bool { + cert, err := w.GetCertificate(nil) + if err != nil { + return false + } + return cert.Leaf.SerialNumber.Int64() == 999 + }, 5*time.Second, 50*time.Millisecond, "watcher should detect cert change") +} + +func TestIsRelevantFile(t *testing.T) { + assert.True(t, isRelevantFile("tls.crt", "tls.crt", "tls.key")) + assert.True(t, isRelevantFile("tls.key", "tls.crt", "tls.key")) + assert.True(t, isRelevantFile("..data", "tls.crt", "tls.key")) + assert.False(t, isRelevantFile("other.txt", "tls.crt", "tls.key")) +} + +// TestWatchSymlinkRotation simulates Kubernetes secret volume updates where +// the data directory is atomically swapped via a ..data symlink. +func TestWatchSymlinkRotation(t *testing.T) { + base := t.TempDir() + + // Create initial target directory with certs. + dir1 := filepath.Join(base, "dir1") + require.NoError(t, os.Mkdir(dir1, 0o755)) + certPEM1, keyPEM1 := generateSelfSignedCert(t, 1) + require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.crt"), certPEM1, 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.key"), keyPEM1, 0o600)) + + // Create ..data symlink pointing to dir1. + dataLink := filepath.Join(base, "..data") + require.NoError(t, os.Symlink(dir1, dataLink)) + + // Create tls.crt and tls.key as symlinks to ..data/{file}. 
+ certLink := filepath.Join(base, "tls.crt") + keyLink := filepath.Join(base, "tls.key") + require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.crt"), certLink)) + require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.key"), keyLink)) + + w, err := NewWatcher(certLink, keyLink, nil) + require.NoError(t, err) + + cert, err := w.GetCertificate(nil) + require.NoError(t, err) + assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64()) + + w.pollInterval = 100 * time.Millisecond + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go w.Watch(ctx) + + // Simulate k8s atomic rotation: create dir2, swap ..data symlink. + dir2 := filepath.Join(base, "dir2") + require.NoError(t, os.Mkdir(dir2, 0o755)) + certPEM2, keyPEM2 := generateSelfSignedCert(t, 777) + require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.crt"), certPEM2, 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.key"), keyPEM2, 0o600)) + + // Atomic swap: create temp link, then rename over ..data. + tmpLink := filepath.Join(base, "..data_tmp") + require.NoError(t, os.Symlink(dir2, tmpLink)) + require.NoError(t, os.Rename(tmpLink, dataLink)) + + require.Eventually(t, func() bool { + cert, err := w.GetCertificate(nil) + if err != nil { + return false + } + return cert.Leaf.SerialNumber.Int64() == 777 + }, 5*time.Second, 50*time.Millisecond, "watcher should detect symlink rotation") +} + +// TestPollLoopDetectsChanges verifies the poll-only fallback path works. +func TestPollLoopDetectsChanges(t *testing.T) { + dir := t.TempDir() + certPEM1, keyPEM1 := generateSelfSignedCert(t, 1) + writeCert(t, dir, certPEM1, keyPEM1) + + w, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + require.NoError(t, err) + + w.pollInterval = 100 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Directly use pollLoop to test the fallback path. 
+ go w.pollLoop(ctx) + + certPEM2, keyPEM2 := generateSelfSignedCert(t, 555) + writeCert(t, dir, certPEM2, keyPEM2) + + require.Eventually(t, func() bool { + cert, err := w.GetCertificate(nil) + if err != nil { + return false + } + return cert.Leaf.SerialNumber.Int64() == 555 + }, 5*time.Second, 50*time.Millisecond, "poll loop should detect cert change") +} + +func TestGetCertificateConcurrency(t *testing.T) { + dir := t.TempDir() + certPEM, keyPEM := generateSelfSignedCert(t, 1) + writeCert(t, dir, certPEM, keyPEM) + + w, err := NewWatcher( + filepath.Join(dir, "tls.crt"), + filepath.Join(dir, "tls.key"), + nil, + ) + require.NoError(t, err) + + // Hammer GetCertificate concurrently while reloading. + done := make(chan struct{}) + go func() { + for i := 0; i < 100; i++ { + w.tryReload() + } + close(done) + }() + + for i := 0; i < 1000; i++ { + cert, err := w.GetCertificate(&tls.ClientHelloInfo{}) + assert.NoError(t, err) + assert.NotNil(t, cert) + } + + <-done +} diff --git a/proxy/internal/flock/flock_other.go b/proxy/internal/flock/flock_other.go new file mode 100644 index 000000000..c73e1e217 --- /dev/null +++ b/proxy/internal/flock/flock_other.go @@ -0,0 +1,20 @@ +//go:build !unix + +package flock + +import ( + "context" + "os" +) + +// Lock is a no-op on non-Unix platforms. Returns (nil, nil) to indicate +// that no lock was acquired; callers must treat a nil file as "proceed +// without lock" rather than "lock held by someone else." +func Lock(_ context.Context, _ string) (*os.File, error) { + return nil, nil +} + +// Unlock is a no-op on non-Unix platforms. 
+func Unlock(_ *os.File) error { + return nil +} diff --git a/proxy/internal/flock/flock_test.go b/proxy/internal/flock/flock_test.go new file mode 100644 index 000000000..501a173f7 --- /dev/null +++ b/proxy/internal/flock/flock_test.go @@ -0,0 +1,79 @@ +//go:build unix + +package flock + +import ( + "context" + "os" + "path/filepath" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLockUnlock(t *testing.T) { + lockPath := filepath.Join(t.TempDir(), "test.lock") + + f, err := Lock(context.Background(), lockPath) + require.NoError(t, err) + require.NotNil(t, f) + + _, err = os.Stat(lockPath) + assert.NoError(t, err, "lock file should exist") + + err = Unlock(f) + assert.NoError(t, err) +} + +func TestUnlockNil(t *testing.T) { + err := Unlock(nil) + assert.NoError(t, err, "unlocking nil should be a no-op") +} + +func TestLockRespectsContext(t *testing.T) { + lockPath := filepath.Join(t.TempDir(), "test.lock") + + f1, err := Lock(context.Background(), lockPath) + require.NoError(t, err) + defer func() { require.NoError(t, Unlock(f1)) }() + + ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer cancel() + + _, err = Lock(ctx, lockPath) + assert.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestLockBlocks(t *testing.T) { + lockPath := filepath.Join(t.TempDir(), "test.lock") + + f1, err := Lock(context.Background(), lockPath) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + start := time.Now() + var elapsed time.Duration + + go func() { + defer wg.Done() + f2, err := Lock(context.Background(), lockPath) + elapsed = time.Since(start) + assert.NoError(t, err) + if f2 != nil { + assert.NoError(t, Unlock(f2)) + } + }() + + // Hold the lock for 200ms, then release. 
+ time.Sleep(200 * time.Millisecond) + require.NoError(t, Unlock(f1)) + + wg.Wait() + assert.GreaterOrEqual(t, elapsed, 150*time.Millisecond, + "Lock should have blocked for at least ~200ms") +} diff --git a/proxy/internal/flock/flock_unix.go b/proxy/internal/flock/flock_unix.go new file mode 100644 index 000000000..84a77e1dc --- /dev/null +++ b/proxy/internal/flock/flock_unix.go @@ -0,0 +1,69 @@ +//go:build unix + +// Package flock provides best-effort advisory file locking using flock(2). +// +// This is used for cross-replica coordination (e.g. preventing duplicate +// ACME requests). Note that flock(2) does NOT work reliably on NFS volumes: +// on NFSv3 it depends on the NLM daemon, on NFSv4 Linux emulates it via +// fcntl locks with different semantics. Callers must treat lock failures +// as non-fatal and proceed without the lock. +package flock + +import ( + "context" + "errors" + "fmt" + "os" + "syscall" + "time" + + log "github.com/sirupsen/logrus" +) + +const retryInterval = 100 * time.Millisecond + +// Lock acquires an exclusive advisory lock on the given file path. +// It creates the lock file if it does not exist. The lock attempt +// respects context cancellation by using non-blocking flock with polling. +// The caller must call Unlock with the returned *os.File when done. 
+func Lock(ctx context.Context, path string) (*os.File, error) { + f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600) + if err != nil { + return nil, fmt.Errorf("open lock file %s: %w", path, err) + } + + for { + if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil { + return f, nil + } else if !errors.Is(err, syscall.EWOULDBLOCK) { + if cerr := f.Close(); cerr != nil { + log.Debugf("close lock file %s: %v", path, cerr) + } + return nil, fmt.Errorf("acquire lock on %s: %w", path, err) + } + + select { + case <-ctx.Done(): + if cerr := f.Close(); cerr != nil { + log.Debugf("close lock file %s: %v", path, cerr) + } + return nil, ctx.Err() + case <-time.After(retryInterval): + } + } +} + +// Unlock releases the lock and closes the file. +func Unlock(f *os.File) error { + if f == nil { + return nil + } + + defer f.Close() + + if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil { + return fmt.Errorf("release lock: %w", err) + } + + return nil +} diff --git a/proxy/internal/k8s/lease.go b/proxy/internal/k8s/lease.go new file mode 100644 index 000000000..cef15521a --- /dev/null +++ b/proxy/internal/k8s/lease.go @@ -0,0 +1,282 @@ +// Package k8s provides a lightweight Kubernetes API client for coordination +// Leases. It uses raw HTTP calls against the mounted service account +// credentials, avoiding a dependency on client-go. 
+package k8s
+
+import (
+	"bytes"
+	"context"
+	"crypto/sha256"
+	"crypto/tls"
+	"crypto/x509"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"strings"
+	"time"
+)
+
+const (
+	saTokenPath     = "/var/run/secrets/kubernetes.io/serviceaccount/token"
+	saNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
+	saCACertPath    = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
+
+	leaseAPIPath = "/apis/coordination.k8s.io/v1"
+)
+
+// ErrConflict is returned when a Lease update fails due to a
+// resourceVersion mismatch (another writer updated the object first).
+var ErrConflict = errors.New("conflict: resource version mismatch")
+
+// Lease represents a coordination.k8s.io/v1 Lease object with only the
+// fields needed for distributed locking.
+type Lease struct {
+	APIVersion string        `json:"apiVersion"`
+	Kind       string        `json:"kind"`
+	Metadata   LeaseMetadata `json:"metadata"`
+	Spec       LeaseSpec     `json:"spec"`
+}
+
+// LeaseMetadata holds the standard k8s object metadata fields used by Leases.
+type LeaseMetadata struct {
+	Name            string            `json:"name"`
+	Namespace       string            `json:"namespace,omitempty"`
+	ResourceVersion string            `json:"resourceVersion,omitempty"`
+	Annotations     map[string]string `json:"annotations,omitempty"`
+}
+
+// LeaseSpec holds the Lease specification fields.
+type LeaseSpec struct {
+	HolderIdentity *string `json:"holderIdentity"`
+	LeaseDurationSeconds *int32 `json:"leaseDurationSeconds,omitempty"`
+	AcquireTime *MicroTime `json:"acquireTime"`
+	RenewTime *MicroTime `json:"renewTime"`
+}
+
+// MicroTime wraps time.Time with Kubernetes MicroTime JSON formatting.
+type MicroTime struct {
+	time.Time
+}
+
+// microTimeFormat is the k8s MicroTime layout; values are marshaled in UTC.
+const microTimeFormat = "2006-01-02T15:04:05.000000Z"
+
+// MarshalJSON implements json.Marshaler with k8s MicroTime format.
+func (t MicroTime) MarshalJSON() ([]byte, error) { + return json.Marshal(t.UTC().Format(microTimeFormat)) +} + +// UnmarshalJSON implements json.Unmarshaler with k8s MicroTime format. +func (t *MicroTime) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + if s == "" { + t.Time = time.Time{} + return nil + } + + parsed, err := time.Parse(microTimeFormat, s) + if err != nil { + return fmt.Errorf("parse MicroTime %q: %w", s, err) + } + t.Time = parsed + return nil +} + +// LeaseClient talks to the Kubernetes coordination API using raw HTTP. +type LeaseClient struct { + baseURL string + namespace string + httpClient *http.Client +} + +// NewLeaseClient creates a client that authenticates via the pod's +// mounted service account. It reads the namespace and CA certificate +// at construction time (they don't rotate) but reads the bearer token +// fresh on each request (tokens rotate). +func NewLeaseClient() (*LeaseClient, error) { + host := os.Getenv("KUBERNETES_SERVICE_HOST") + port := os.Getenv("KUBERNETES_SERVICE_PORT") + if host == "" || port == "" { + return nil, fmt.Errorf("KUBERNETES_SERVICE_HOST/PORT not set") + } + + ns, err := os.ReadFile(saNamespacePath) + if err != nil { + return nil, fmt.Errorf("read namespace from %s: %w", saNamespacePath, err) + } + + caCert, err := os.ReadFile(saCACertPath) + if err != nil { + return nil, fmt.Errorf("read CA cert from %s: %w", saCACertPath, err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("parse CA certificate from %s", saCACertPath) + } + + return &LeaseClient{ + baseURL: fmt.Sprintf("https://%s:%s", host, port), + namespace: strings.TrimSpace(string(ns)), + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: pool, + }, + }, + }, + }, nil +} + +// Namespace returns the namespace this client operates in. 
+func (c *LeaseClient) Namespace() string { + return c.namespace +} + +// Get retrieves a Lease by name. Returns (nil, nil) if the Lease does not exist. +func (c *LeaseClient) Get(ctx context.Context, name string) (*Lease, error) { + url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, name) + + resp, err := c.doRequest(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return nil, nil + } + if resp.StatusCode != http.StatusOK { + return nil, c.readError(resp) + } + + var lease Lease + if err := json.NewDecoder(resp.Body).Decode(&lease); err != nil { + return nil, fmt.Errorf("decode lease response: %w", err) + } + return &lease, nil +} + +// Create creates a new Lease. Returns the created Lease with server-assigned +// fields like resourceVersion populated. +func (c *LeaseClient) Create(ctx context.Context, lease *Lease) (*Lease, error) { + url := fmt.Sprintf("%s%s/namespaces/%s/leases", c.baseURL, leaseAPIPath, c.namespace) + + lease.APIVersion = "coordination.k8s.io/v1" + lease.Kind = "Lease" + if lease.Metadata.Namespace == "" { + lease.Metadata.Namespace = c.namespace + } + + resp, err := c.doRequest(ctx, http.MethodPost, url, lease) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusConflict { + return nil, ErrConflict + } + if resp.StatusCode != http.StatusCreated { + return nil, c.readError(resp) + } + + var created Lease + if err := json.NewDecoder(resp.Body).Decode(&created); err != nil { + return nil, fmt.Errorf("decode created lease: %w", err) + } + return &created, nil +} + +// Update replaces a Lease. The lease.Metadata.ResourceVersion must match +// the current server value (optimistic concurrency). Returns ErrConflict +// on version mismatch. 
+func (c *LeaseClient) Update(ctx context.Context, lease *Lease) (*Lease, error) { + url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, lease.Metadata.Name) + + lease.APIVersion = "coordination.k8s.io/v1" + lease.Kind = "Lease" + + resp, err := c.doRequest(ctx, http.MethodPut, url, lease) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusConflict { + return nil, ErrConflict + } + if resp.StatusCode != http.StatusOK { + return nil, c.readError(resp) + } + + var updated Lease + if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil { + return nil, fmt.Errorf("decode updated lease: %w", err) + } + return &updated, nil +} + +func (c *LeaseClient) doRequest(ctx context.Context, method, url string, body any) (*http.Response, error) { + token, err := readToken() + if err != nil { + return nil, fmt.Errorf("read service account token: %w", err) + } + + var bodyReader io.Reader + if body != nil { + data, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("marshal request body: %w", err) + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Accept", "application/json") + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + + return c.httpClient.Do(req) +} + +func readToken() (string, error) { + data, err := os.ReadFile(saTokenPath) + if err != nil { + return "", fmt.Errorf("read %s: %w", saTokenPath, err) + } + return strings.TrimSpace(string(data)), nil +} + +func (c *LeaseClient) readError(resp *http.Response) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return fmt.Errorf("k8s API %s %d: %s", resp.Request.URL.Path, resp.StatusCode, string(body)) +} + +// LeaseNameForDomain returns a deterministic, 
DNS-label-safe Lease name +// for the given domain. The domain is hashed to avoid dots and length issues. +func LeaseNameForDomain(domain string) string { + h := sha256.Sum256([]byte(domain)) + return "cert-lock-" + hex.EncodeToString(h[:8]) +} + +// InCluster reports whether the process is running inside a Kubernetes pod +// by checking for the KUBERNETES_SERVICE_HOST environment variable. +func InCluster() bool { + _, exists := os.LookupEnv("KUBERNETES_SERVICE_HOST") + return exists +} diff --git a/proxy/internal/k8s/lease_test.go b/proxy/internal/k8s/lease_test.go new file mode 100644 index 000000000..9d5d3c6ce --- /dev/null +++ b/proxy/internal/k8s/lease_test.go @@ -0,0 +1,102 @@ +package k8s + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLeaseNameForDomain(t *testing.T) { + tests := []struct { + domain string + }{ + {"example.com"}, + {"app.example.com"}, + {"another.domain.io"}, + } + + seen := make(map[string]string) + for _, tc := range tests { + name := LeaseNameForDomain(tc.domain) + + assert.True(t, len(name) <= 63, "must be valid DNS label length") + assert.Regexp(t, `^cert-lock-[0-9a-f]{16}$`, name, + "must match expected format for domain %q", tc.domain) + + // Same input produces same output. + assert.Equal(t, name, LeaseNameForDomain(tc.domain), "must be deterministic") + + // Different domains produce different names. 
+ if prev, ok := seen[name]; ok { + t.Errorf("collision: %q and %q both map to %s", prev, tc.domain, name) + } + seen[name] = tc.domain + } +} + +func TestMicroTimeJSON(t *testing.T) { + ts := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC) + mt := &MicroTime{Time: ts} + + data, err := json.Marshal(mt) + require.NoError(t, err) + assert.Equal(t, `"2024-06-15T10:30:00.000000Z"`, string(data)) + + var decoded MicroTime + require.NoError(t, json.Unmarshal(data, &decoded)) + assert.True(t, ts.Equal(decoded.Time), "round-trip should preserve time") +} + +func TestMicroTimeNullJSON(t *testing.T) { + // Null pointer serializes as JSON null via the Lease struct. + spec := LeaseSpec{ + HolderIdentity: nil, + AcquireTime: nil, + RenewTime: nil, + } + + data, err := json.Marshal(spec) + require.NoError(t, err) + assert.Contains(t, string(data), `"acquireTime":null`) + assert.Contains(t, string(data), `"renewTime":null`) +} + +func TestLeaseJSONRoundTrip(t *testing.T) { + holder := "pod-abc" + dur := int32(300) + now := MicroTime{Time: time.Now().UTC().Truncate(time.Microsecond)} + + original := Lease{ + APIVersion: "coordination.k8s.io/v1", + Kind: "Lease", + Metadata: LeaseMetadata{ + Name: "cert-lock-abcdef0123456789", + Namespace: "default", + ResourceVersion: "12345", + Annotations: map[string]string{ + "netbird.io/domain": "app.example.com", + }, + }, + Spec: LeaseSpec{ + HolderIdentity: &holder, + LeaseDurationSeconds: &dur, + AcquireTime: &now, + RenewTime: &now, + }, + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + var decoded Lease + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, original.Metadata.Name, decoded.Metadata.Name) + assert.Equal(t, original.Metadata.ResourceVersion, decoded.Metadata.ResourceVersion) + assert.Equal(t, *original.Spec.HolderIdentity, *decoded.Spec.HolderIdentity) + assert.Equal(t, *original.Spec.LeaseDurationSeconds, *decoded.Spec.LeaseDurationSeconds) + assert.True(t, 
original.Spec.AcquireTime.Equal(decoded.Spec.AcquireTime.Time)) +} diff --git a/proxy/server.go b/proxy/server.go index dc08ce9e1..c981e6f84 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -32,6 +32,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" + "github.com/netbirdio/netbird/proxy/internal/certwatch" "github.com/netbirdio/netbird/proxy/internal/debug" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" "github.com/netbirdio/netbird/proxy/internal/health" @@ -65,6 +66,8 @@ type Server struct { ProxyURL string ManagementAddress string CertificateDirectory string + CertificateFile string + CertificateKeyFile string GenerateACMECertificates bool ACMEChallengeAddress string ACMEDirectory string @@ -210,16 +213,17 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { "ServerName": s.ProxyURL, }).Debug("started ACME challenge server") } else { - s.Logger.Debug("ACME certificates disabled, using static certificates") - // Otherwise pull some certificates from expected locations. - cert, err := tls.LoadX509KeyPair( - filepath.Join(s.CertificateDirectory, "tls.crt"), - filepath.Join(s.CertificateDirectory, "tls.key"), - ) + s.Logger.Debug("ACME certificates disabled, using static certificates with file watching") + certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile) + keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile) + + certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger) if err != nil { - return fmt.Errorf("load provided certificate: %w", err) + return fmt.Errorf("initialize certificate watcher: %w", err) } - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + go certWatcher.Watch(ctx) + + tlsConfig.GetCertificate = certWatcher.GetCertificate } // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.