mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Add cert hot reload and cert file locking
Adds file-watching certificate hot reload, cross-replica ACME certificate lock coordination via flock (Unix) and Kubernetes lease objects.
This commit is contained in:
@@ -45,6 +45,8 @@ var (
|
||||
oidcScopes string
|
||||
forwardedProto string
|
||||
trustedProxies string
|
||||
certFile string
|
||||
certKeyFile string
|
||||
)
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@@ -74,6 +76,8 @@ func init() {
|
||||
rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated")
|
||||
rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https")
|
||||
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')")
|
||||
rootCmd.Flags().StringVar(&certFile, "cert-file", envStringOrDefault("NB_PROXY_CERTIFICATE_FILE", "tls.crt"), "TLS certificate filename within the certificate directory")
|
||||
rootCmd.Flags().StringVar(&certKeyFile, "cert-key-file", envStringOrDefault("NB_PROXY_CERTIFICATE_KEY_FILE", "tls.key"), "TLS certificate key filename within the certificate directory")
|
||||
}
|
||||
|
||||
// Execute runs the root command.
|
||||
@@ -127,6 +131,8 @@ func runServer(cmd *cobra.Command, args []string) error {
|
||||
ProxyURL: proxyURL,
|
||||
ProxyToken: proxyToken,
|
||||
CertificateDirectory: certDir,
|
||||
CertificateFile: certFile,
|
||||
CertificateKeyFile: certKeyFile,
|
||||
GenerateACMECertificates: acmeCerts,
|
||||
ACMEChallengeAddress: acmeAddr,
|
||||
ACMEDirectory: acmeDir,
|
||||
|
||||
99
proxy/internal/acme/locker.go
Normal file
99
proxy/internal/acme/locker.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/flock"
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
)
|
||||
|
||||
// certLocker abstracts distributed mutual exclusion around certificate
// issuance. Implementations must tolerate concurrent use from multiple
// goroutines.
type certLocker interface {
	// Lock blocks until an exclusive per-domain lock is held, the context
	// is cancelled, or an unrecoverable error occurs. On success the
	// returned unlock function must be invoked exactly once after the
	// critical section finishes.
	Lock(ctx context.Context, domain string) (unlock func(), err error)
}
|
||||
|
||||
// CertLockMethod controls how ACME certificate locks are coordinated.
type CertLockMethod string

// Supported lock coordination strategies.
const (
	// CertLockAuto picks k8s-lease when running inside a Kubernetes pod
	// and flock everywhere else.
	CertLockAuto CertLockMethod = "auto"
	// CertLockFlock uses advisory file locks via flock(2).
	CertLockFlock CertLockMethod = "flock"
	// CertLockK8sLease uses Kubernetes coordination Leases.
	CertLockK8sLease CertLockMethod = "k8s-lease"
)
|
||||
|
||||
func newCertLocker(method CertLockMethod, certDir string, logger *log.Logger) certLocker {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
if method == "" || method == CertLockAuto {
|
||||
if k8s.InCluster() {
|
||||
method = CertLockK8sLease
|
||||
} else {
|
||||
method = CertLockFlock
|
||||
}
|
||||
logger.Infof("auto-detected cert lock method: %s", method)
|
||||
}
|
||||
|
||||
switch method {
|
||||
case CertLockK8sLease:
|
||||
locker, err := newK8sLeaseLocker(logger)
|
||||
if err != nil {
|
||||
logger.Warnf("create k8s lease locker, falling back to flock: %v", err)
|
||||
return newFlockLocker(certDir, logger)
|
||||
}
|
||||
logger.Infof("using k8s lease locker in namespace %s", locker.client.Namespace())
|
||||
return locker
|
||||
default:
|
||||
logger.Infof("using flock cert locker in %s", certDir)
|
||||
return newFlockLocker(certDir, logger)
|
||||
}
|
||||
}
|
||||
|
||||
type flockLocker struct {
|
||||
certDir string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func newFlockLocker(certDir string, logger *log.Logger) *flockLocker {
|
||||
return &flockLocker{certDir: certDir, logger: logger}
|
||||
}
|
||||
|
||||
// Lock acquires an advisory file lock for the given domain.
|
||||
func (l *flockLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
lockPath := filepath.Join(l.certDir, domain+".lock")
|
||||
lockFile, err := flock.Lock(ctx, lockPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// nil lockFile means locking is not supported (non-unix).
|
||||
if lockFile == nil {
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
return func() {
|
||||
if err := flock.Unlock(lockFile); err != nil {
|
||||
l.logger.Debugf("release cert lock for domain %q: %v", domain, err)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
type noopLocker struct{}
|
||||
|
||||
// Lock is a no-op that always succeeds immediately.
|
||||
func (noopLocker) Lock(context.Context, string) (func(), error) {
|
||||
return func() {}, nil
|
||||
}
|
||||
197
proxy/internal/acme/locker_k8s.go
Normal file
197
proxy/internal/acme/locker_k8s.go
Normal file
@@ -0,0 +1,197 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/k8s"
|
||||
)
|
||||
|
||||
const (
|
||||
// leaseDurationSec is the Kubernetes Lease TTL. If the holder crashes without
|
||||
// releasing the lock, other replicas must wait this long before taking over.
|
||||
// This is intentionally generous: in the worst case two replicas may both
|
||||
// issue an ACME request for the same domain, which is harmless (the CA
|
||||
// deduplicates and the cache converges).
|
||||
leaseDurationSec = 300
|
||||
retryBaseBackoff = 500 * time.Millisecond
|
||||
retryMaxBackoff = 10 * time.Second
|
||||
)
|
||||
|
||||
type k8sLeaseLocker struct {
|
||||
client *k8s.LeaseClient
|
||||
identity string
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func newK8sLeaseLocker(logger *log.Logger) (*k8sLeaseLocker, error) {
|
||||
client, err := k8s.NewLeaseClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create k8s lease client: %w", err)
|
||||
}
|
||||
|
||||
identity, err := os.Hostname()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get hostname: %w", err)
|
||||
}
|
||||
|
||||
return &k8sLeaseLocker{
|
||||
client: client,
|
||||
identity: identity,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Lock acquires a Kubernetes Lease for the given domain using optimistic
|
||||
// concurrency. It retries with exponential backoff until the lease is
|
||||
// acquired or the context is cancelled.
|
||||
func (l *k8sLeaseLocker) Lock(ctx context.Context, domain string) (func(), error) {
|
||||
leaseName := k8s.LeaseNameForDomain(domain)
|
||||
backoff := retryBaseBackoff
|
||||
|
||||
for {
|
||||
acquired, err := l.tryAcquire(ctx, leaseName, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("acquire lease %s for %q: %w", leaseName, domain, err)
|
||||
}
|
||||
if acquired {
|
||||
l.logger.Debugf("k8s lease %s acquired for domain %q", leaseName, domain)
|
||||
return l.unlockFunc(leaseName, domain), nil
|
||||
}
|
||||
|
||||
l.logger.Debugf("k8s lease %s held by another replica, retrying in %s", leaseName, backoff)
|
||||
|
||||
timer := time.NewTimer(backoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
backoff *= 2
|
||||
if backoff > retryMaxBackoff {
|
||||
backoff = retryMaxBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryAcquire attempts to create or take over a Lease. Returns (true, nil)
|
||||
// on success, (false, nil) if the lease is held and not stale, or an error.
|
||||
func (l *k8sLeaseLocker) tryAcquire(ctx context.Context, name, domain string) (bool, error) {
|
||||
existing, err := l.client.Get(ctx, name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
now := k8s.MicroTime{Time: time.Now().UTC()}
|
||||
dur := int32(leaseDurationSec)
|
||||
|
||||
if existing == nil {
|
||||
lease := &k8s.Lease{
|
||||
Metadata: k8s.LeaseMetadata{
|
||||
Name: name,
|
||||
Annotations: map[string]string{
|
||||
"netbird.io/domain": domain,
|
||||
},
|
||||
},
|
||||
Spec: k8s.LeaseSpec{
|
||||
HolderIdentity: &l.identity,
|
||||
LeaseDurationSeconds: &dur,
|
||||
AcquireTime: &now,
|
||||
RenewTime: &now,
|
||||
},
|
||||
}
|
||||
|
||||
if _, err := l.client.Create(ctx, lease); errors.Is(err, k8s.ErrConflict) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if !l.canTakeover(existing) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
existing.Spec.HolderIdentity = &l.identity
|
||||
existing.Spec.LeaseDurationSeconds = &dur
|
||||
existing.Spec.AcquireTime = &now
|
||||
existing.Spec.RenewTime = &now
|
||||
|
||||
if _, err := l.client.Update(ctx, existing); errors.Is(err, k8s.ErrConflict) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// canTakeover returns true if the lease is free (no holder) or stale
|
||||
// (renewTime + leaseDuration has passed).
|
||||
func (l *k8sLeaseLocker) canTakeover(lease *k8s.Lease) bool {
|
||||
holder := lease.Spec.HolderIdentity
|
||||
if holder == nil || *holder == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// We already hold it (e.g. from a previous crashed attempt).
|
||||
if *holder == l.identity {
|
||||
return true
|
||||
}
|
||||
|
||||
if lease.Spec.RenewTime == nil || lease.Spec.LeaseDurationSeconds == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
expiry := lease.Spec.RenewTime.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second)
|
||||
if time.Now().After(expiry) {
|
||||
l.logger.Infof("k8s lease %s held by %q is stale (expired %s ago), taking over",
|
||||
lease.Metadata.Name, *holder, time.Since(expiry).Round(time.Second))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// unlockFunc returns a closure that releases the lease by clearing the holder.
|
||||
func (l *k8sLeaseLocker) unlockFunc(name, domain string) func() {
|
||||
return func() {
|
||||
// Use a fresh context: the parent may already be cancelled.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Re-GET to get current resourceVersion (ours may be stale if
|
||||
// the lock was held for a long time and something updated it).
|
||||
current, err := l.client.Get(ctx, name)
|
||||
if err != nil {
|
||||
l.logger.Debugf("release k8s lease %s for %q: get: %v", name, domain, err)
|
||||
return
|
||||
}
|
||||
if current == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Only clear if we're still the holder.
|
||||
if current.Spec.HolderIdentity == nil || *current.Spec.HolderIdentity != l.identity {
|
||||
l.logger.Debugf("k8s lease %s for %q: holder changed to %v, skip release",
|
||||
name, domain, current.Spec.HolderIdentity)
|
||||
return
|
||||
}
|
||||
|
||||
empty := ""
|
||||
current.Spec.HolderIdentity = &empty
|
||||
current.Spec.AcquireTime = nil
|
||||
current.Spec.RenewTime = nil
|
||||
|
||||
if _, err := l.client.Update(ctx, current); err != nil {
|
||||
l.logger.Debugf("release k8s lease %s for %q: update: %v", name, domain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
65
proxy/internal/acme/locker_test.go
Normal file
65
proxy/internal/acme/locker_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFlockLockerRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
locker := newFlockLocker(dir, nil)
|
||||
|
||||
unlock, err := locker.Lock(context.Background(), "example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, unlock)
|
||||
|
||||
// Lock file should exist.
|
||||
assert.FileExists(t, filepath.Join(dir, "example.com.lock"))
|
||||
|
||||
unlock()
|
||||
}
|
||||
|
||||
func TestNoopLocker(t *testing.T) {
|
||||
locker := noopLocker{}
|
||||
unlock, err := locker.Lock(context.Background(), "example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, unlock)
|
||||
unlock()
|
||||
}
|
||||
|
||||
func TestNewCertLockerDefaultsToFlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// t.Setenv registers cleanup to restore the original value.
|
||||
// os.Unsetenv is needed because the production code uses LookupEnv,
|
||||
// which distinguishes "empty" from "not set".
|
||||
t.Setenv("KUBERNETES_SERVICE_HOST", "")
|
||||
os.Unsetenv("KUBERNETES_SERVICE_HOST")
|
||||
locker := newCertLocker(CertLockAuto, dir, nil)
|
||||
|
||||
_, ok := locker.(*flockLocker)
|
||||
assert.True(t, ok, "auto without k8s env should select flockLocker")
|
||||
}
|
||||
|
||||
func TestNewCertLockerExplicitFlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
locker := newCertLocker(CertLockFlock, dir, nil)
|
||||
|
||||
_, ok := locker.(*flockLocker)
|
||||
assert.True(t, ok, "explicit flock should select flockLocker")
|
||||
}
|
||||
|
||||
func TestNewCertLockerK8sFallsBackToFlock(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
// k8s-lease without SA files should fall back to flock.
|
||||
locker := newCertLocker(CertLockK8sLease, dir, nil)
|
||||
|
||||
_, ok := locker.(*flockLocker)
|
||||
assert.True(t, ok, "k8s-lease without SA should fall back to flockLocker")
|
||||
}
|
||||
@@ -23,9 +23,14 @@ type certificateNotifier interface {
|
||||
NotifyCertificateIssued(ctx context.Context, accountID, reverseProxyID, domain string) error
|
||||
}
|
||||
|
||||
// Manager wraps autocert.Manager with domain tracking and cross-replica
|
||||
// coordination via a pluggable locking strategy. The locker prevents
|
||||
// duplicate ACME requests when multiple replicas share a certificate cache.
|
||||
type Manager struct {
|
||||
*autocert.Manager
|
||||
|
||||
certDir string
|
||||
locker certLocker
|
||||
domainsMux sync.RWMutex
|
||||
domains map[string]struct {
|
||||
accountID string
|
||||
@@ -36,11 +41,16 @@ type Manager struct {
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger) *Manager {
|
||||
// NewManager creates a new ACME certificate manager. The certDir is used
|
||||
// for caching certificates. The lockMethod controls cross-replica
|
||||
// coordination strategy (see CertLockMethod constants).
|
||||
func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
mgr := &Manager{
|
||||
certDir: certDir,
|
||||
locker: newCertLocker(lockMethod, certDir, logger),
|
||||
domains: make(map[string]struct {
|
||||
accountID string
|
||||
reverseProxyID string
|
||||
@@ -59,12 +69,15 @@ func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *l
|
||||
return mgr
|
||||
}
|
||||
|
||||
func (mgr *Manager) hostPolicy(ctx context.Context, domain string) error {
|
||||
func (mgr *Manager) hostPolicy(_ context.Context, host string) error {
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
mgr.domainsMux.RLock()
|
||||
_, exists := mgr.domains[domain]
|
||||
_, exists := mgr.domains[host]
|
||||
mgr.domainsMux.RUnlock()
|
||||
if !exists {
|
||||
return fmt.Errorf("unknown domain %q", domain)
|
||||
return fmt.Errorf("unknown domain %q", host)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -84,17 +97,31 @@ func (mgr *Manager) AddDomain(domain, accountID, reverseProxyID string) {
|
||||
}
|
||||
|
||||
// prefetchCertificate proactively triggers certificate generation for a domain.
|
||||
// It acquires a distributed lock to prevent multiple replicas from issuing
|
||||
// duplicate ACME requests. The second replica will block until the first
|
||||
// finishes, then find the certificate in the cache.
|
||||
func (mgr *Manager) prefetchCertificate(domain string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
mgr.logger.Infof("acquiring cert lock for domain %q", domain)
|
||||
lockStart := time.Now()
|
||||
unlock, err := mgr.locker.Lock(ctx, domain)
|
||||
if err != nil {
|
||||
mgr.logger.Warnf("acquire cert lock for domain %q, proceeding without lock: %v", domain, err)
|
||||
} else {
|
||||
mgr.logger.Infof("acquired cert lock for domain %q in %s", domain, time.Since(lockStart))
|
||||
defer unlock()
|
||||
}
|
||||
|
||||
hello := &tls.ClientHelloInfo{
|
||||
ServerName: domain,
|
||||
Conn: &dummyConn{ctx: ctx},
|
||||
}
|
||||
|
||||
mgr.logger.Infof("prefetching certificate for domain %q", domain)
|
||||
start := time.Now()
|
||||
cert, err := mgr.GetCertificate(hello)
|
||||
elapsed := time.Since(start)
|
||||
if err != nil {
|
||||
mgr.logger.Warnf("prefetch certificate for domain %q: %v", domain, err)
|
||||
return
|
||||
@@ -102,11 +129,18 @@ func (mgr *Manager) prefetchCertificate(domain string) {
|
||||
|
||||
now := time.Now()
|
||||
if cert != nil && cert.Leaf != nil {
|
||||
mgr.logCertificateDetails(domain, cert.Leaf, now)
|
||||
leaf := cert.Leaf
|
||||
mgr.logger.Infof("certificate for domain %q ready in %s: serial=%s SANs=%v notAfter=%s",
|
||||
domain, elapsed.Round(time.Millisecond),
|
||||
leaf.SerialNumber.Text(16),
|
||||
leaf.DNSNames,
|
||||
leaf.NotAfter.UTC().Format(time.RFC3339),
|
||||
)
|
||||
mgr.logCertificateDetails(domain, leaf, now)
|
||||
} else {
|
||||
mgr.logger.Infof("certificate for domain %q ready in %s", domain, elapsed.Round(time.Millisecond))
|
||||
}
|
||||
|
||||
mgr.logger.Infof("certificate for domain %q is ready", domain)
|
||||
|
||||
mgr.domainsMux.RLock()
|
||||
info, exists := mgr.domains[domain]
|
||||
mgr.domainsMux.RUnlock()
|
||||
@@ -120,18 +154,12 @@ func (mgr *Manager) prefetchCertificate(domain string) {
|
||||
|
||||
// logCertificateDetails logs certificate validity and SCT timestamps.
|
||||
func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate, now time.Time) {
|
||||
mgr.logger.Infof("certificate for %q: NotBefore=%v, NotAfter=%v, now=%v",
|
||||
domain, cert.NotBefore.UTC(), cert.NotAfter.UTC(), now.UTC())
|
||||
|
||||
if cert.NotBefore.After(now) {
|
||||
mgr.logger.Warnf("certificate for %q NotBefore is in the future by %v", domain, cert.NotBefore.Sub(now))
|
||||
} else {
|
||||
mgr.logger.Infof("certificate for %q NotBefore is %v in the past", domain, now.Sub(cert.NotBefore))
|
||||
}
|
||||
|
||||
sctTimestamps := mgr.parseSCTTimestamps(cert)
|
||||
if len(sctTimestamps) == 0 {
|
||||
mgr.logger.Warnf("certificate for %q has no embedded SCTs", domain)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,7 +168,7 @@ func (mgr *Manager) logCertificateDetails(domain string, cert *x509.Certificate,
|
||||
mgr.logger.Warnf("certificate for %q SCT[%d] timestamp is in the future: %v (by %v)",
|
||||
domain, i, sctTime.UTC(), sctTime.Sub(now))
|
||||
} else {
|
||||
mgr.logger.Infof("certificate for %q SCT[%d] timestamp: %v (%v in the past)",
|
||||
mgr.logger.Debugf("certificate for %q SCT[%d] timestamp: %v (%v in the past)",
|
||||
domain, i, sctTime.UTC(), now.Sub(sctTime))
|
||||
}
|
||||
}
|
||||
|
||||
61
proxy/internal/acme/manager_test.go
Normal file
61
proxy/internal/acme/manager_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package acme
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostPolicy(t *testing.T) {
|
||||
mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "")
|
||||
mgr.AddDomain("example.com", "acc1", "rp1")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "exact domain match",
|
||||
host: "example.com",
|
||||
},
|
||||
{
|
||||
name: "domain with port",
|
||||
host: "example.com:443",
|
||||
},
|
||||
{
|
||||
name: "unknown domain",
|
||||
host: "unknown.com",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "unknown domain with port",
|
||||
host: "unknown.com:443",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
host: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "port only",
|
||||
host: ":443",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := mgr.hostPolicy(context.Background(), tc.host)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown domain")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
279
proxy/internal/certwatch/watcher.go
Normal file
279
proxy/internal/certwatch/watcher.go
Normal file
@@ -0,0 +1,279 @@
|
||||
// Package certwatch watches TLS certificate files on disk and provides
|
||||
// a hot-reloading GetCertificate callback for tls.Config.
|
||||
package certwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPollInterval = 30 * time.Second
|
||||
debounceDelay = 500 * time.Millisecond
|
||||
)
|
||||
|
||||
// Watcher monitors TLS certificate files on disk and caches the loaded
|
||||
// certificate in memory. It detects changes via fsnotify (with a polling
|
||||
// fallback for filesystems like NFS that lack inotify support) and
|
||||
// reloads the certificate pair automatically.
|
||||
type Watcher struct {
|
||||
certPath string
|
||||
keyPath string
|
||||
|
||||
mu sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
leaf *x509.Certificate
|
||||
|
||||
pollInterval time.Duration
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewWatcher creates a Watcher that monitors the given cert and key files.
|
||||
// It performs an initial load of the certificate and returns an error
|
||||
// if the initial load fails.
|
||||
func NewWatcher(certPath, keyPath string, logger *log.Logger) (*Watcher, error) {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
|
||||
w := &Watcher{
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
pollInterval: defaultPollInterval,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if err := w.reload(); err != nil {
|
||||
return nil, fmt.Errorf("initial certificate load: %w", err)
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
// GetCertificate returns the current in-memory certificate.
|
||||
// It is safe for concurrent use and compatible with tls.Config.GetCertificate.
|
||||
func (w *Watcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
w.mu.RLock()
|
||||
defer w.mu.RUnlock()
|
||||
|
||||
return w.cert, nil
|
||||
}
|
||||
|
||||
// Watch starts watching for certificate file changes. It blocks until
|
||||
// ctx is cancelled. It uses fsnotify for immediate detection and falls
|
||||
// back to polling if fsnotify is unavailable (e.g. on NFS).
|
||||
// Even with fsnotify active, a periodic poll runs as a safety net.
|
||||
func (w *Watcher) Watch(ctx context.Context) {
|
||||
// Watch the parent directory rather than individual files. Some volume
|
||||
// mounts use an atomic symlink swap (..data -> timestamped dir), so
|
||||
// watching the parent directory catches the link replacement.
|
||||
certDir := filepath.Dir(w.certPath)
|
||||
keyDir := filepath.Dir(w.keyPath)
|
||||
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
w.logger.Warnf("fsnotify unavailable, using polling only: %v", err)
|
||||
w.pollLoop(ctx)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := watcher.Close(); err != nil {
|
||||
w.logger.Debugf("close fsnotify watcher: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := watcher.Add(certDir); err != nil {
|
||||
w.logger.Warnf("fsnotify watch on %s failed, using polling only: %v", certDir, err)
|
||||
w.pollLoop(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
if keyDir != certDir {
|
||||
if err := watcher.Add(keyDir); err != nil {
|
||||
w.logger.Warnf("fsnotify watch on %s failed: %v", keyDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
w.logger.Infof("watching certificate files in %s", certDir)
|
||||
w.fsnotifyLoop(ctx, watcher)
|
||||
}
|
||||
|
||||
func (w *Watcher) fsnotifyLoop(ctx context.Context, watcher *fsnotify.Watcher) {
|
||||
certBase := filepath.Base(w.certPath)
|
||||
keyBase := filepath.Base(w.keyPath)
|
||||
|
||||
var debounce *time.Timer
|
||||
defer func() {
|
||||
if debounce != nil {
|
||||
debounce.Stop()
|
||||
}
|
||||
}()
|
||||
|
||||
// Periodic poll as a safety net for missed fsnotify events.
|
||||
pollTicker := time.NewTicker(w.pollInterval)
|
||||
defer pollTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
base := filepath.Base(event.Name)
|
||||
if !isRelevantFile(base, certBase, keyBase) {
|
||||
w.logger.Debugf("fsnotify: ignoring event %s on %s", event.Op, event.Name)
|
||||
continue
|
||||
}
|
||||
if !event.Has(fsnotify.Create) && !event.Has(fsnotify.Write) && !event.Has(fsnotify.Rename) {
|
||||
w.logger.Debugf("fsnotify: ignoring op %s on %s", event.Op, base)
|
||||
continue
|
||||
}
|
||||
|
||||
w.logger.Debugf("fsnotify: detected %s on %s, scheduling reload", event.Op, base)
|
||||
|
||||
// Debounce: cert-manager may write cert and key as separate
|
||||
// operations. Wait briefly to load both at once.
|
||||
if debounce != nil {
|
||||
debounce.Stop()
|
||||
}
|
||||
debounce = time.AfterFunc(debounceDelay, func() {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
w.tryReload()
|
||||
})
|
||||
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.logger.Warnf("fsnotify error: %v", err)
|
||||
|
||||
case <-pollTicker.C:
|
||||
w.tryReload()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) pollLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(w.pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.tryReload()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reload loads the certificate from disk and updates the in-memory cache.
|
||||
func (w *Watcher) reload() error {
|
||||
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse the leaf for comparison on subsequent reloads.
|
||||
if cert.Leaf == nil && len(cert.Certificate) > 0 {
|
||||
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse leaf certificate: %w", err)
|
||||
}
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
w.cert = &cert
|
||||
w.leaf = cert.Leaf
|
||||
w.mu.Unlock()
|
||||
|
||||
w.logCertDetails("loaded certificate", cert.Leaf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tryReload attempts to reload the certificate. It skips the update
|
||||
// if the certificate on disk is identical to the one in memory (same
|
||||
// serial number and issuer) to avoid redundant log noise.
|
||||
func (w *Watcher) tryReload() {
|
||||
cert, err := tls.LoadX509KeyPair(w.certPath, w.keyPath)
|
||||
if err != nil {
|
||||
w.logger.Warnf("reload certificate: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if cert.Leaf == nil && len(cert.Certificate) > 0 {
|
||||
leaf, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
w.logger.Warnf("parse reloaded leaf certificate: %v", err)
|
||||
return
|
||||
}
|
||||
cert.Leaf = leaf
|
||||
}
|
||||
|
||||
w.mu.Lock()
|
||||
|
||||
if w.leaf != nil && cert.Leaf != nil &&
|
||||
w.leaf.SerialNumber.Cmp(cert.Leaf.SerialNumber) == 0 &&
|
||||
w.leaf.Issuer.CommonName == cert.Leaf.Issuer.CommonName {
|
||||
w.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
prev := w.leaf
|
||||
w.cert = &cert
|
||||
w.leaf = cert.Leaf
|
||||
w.mu.Unlock()
|
||||
|
||||
w.logCertChange(prev, cert.Leaf)
|
||||
}
|
||||
|
||||
func (w *Watcher) logCertDetails(msg string, leaf *x509.Certificate) {
|
||||
if leaf == nil {
|
||||
w.logger.Info(msg)
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Infof("%s: subject=%q serial=%s SANs=%v notAfter=%s",
|
||||
msg,
|
||||
leaf.Subject.CommonName,
|
||||
leaf.SerialNumber.Text(16),
|
||||
leaf.DNSNames,
|
||||
leaf.NotAfter.UTC().Format(time.RFC3339),
|
||||
)
|
||||
}
|
||||
|
||||
func (w *Watcher) logCertChange(prev, next *x509.Certificate) {
|
||||
if prev == nil || next == nil {
|
||||
w.logCertDetails("certificate reloaded from disk", next)
|
||||
return
|
||||
}
|
||||
|
||||
w.logger.Infof("certificate reloaded from disk: subject=%q -> %q serial=%s -> %s notAfter=%s -> %s",
|
||||
prev.Subject.CommonName, next.Subject.CommonName,
|
||||
prev.SerialNumber.Text(16), next.SerialNumber.Text(16),
|
||||
prev.NotAfter.UTC().Format(time.RFC3339), next.NotAfter.UTC().Format(time.RFC3339),
|
||||
)
|
||||
}
|
||||
|
||||
// isRelevantFile reports whether a changed file name warrants a reload:
// the cert or key file itself, or the ..data symlink used by atomic
// volume mounts.
func isRelevantFile(changed, certBase, keyBase string) bool {
	switch changed {
	case certBase, keyBase, "..data":
		return true
	}
	return false
}
|
||||
292
proxy/internal/certwatch/watcher_test.go
Normal file
292
proxy/internal/certwatch/watcher_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package certwatch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func generateSelfSignedCert(t *testing.T, serial int64) (certPEM, keyPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
template := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(serial),
|
||||
Subject: pkix.Name{CommonName: "test"},
|
||||
NotBefore: time.Now().Add(-time.Hour),
|
||||
NotAfter: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
|
||||
require.NoError(t, err)
|
||||
|
||||
certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
require.NoError(t, err)
|
||||
keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
return certPEM, keyPEM
|
||||
}
|
||||
|
||||
func writeCert(t *testing.T, dir string, certPEM, keyPEM []byte) {
|
||||
t.Helper()
|
||||
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.crt"), certPEM, 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir, "tls.key"), keyPEM, 0o600))
|
||||
}
|
||||
|
||||
func TestNewWatcher(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM, keyPEM)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cert)
|
||||
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
|
||||
}
|
||||
|
||||
func TestNewWatcherMissingFiles(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
_, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestReload(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 100)
|
||||
writeCert(t, dir, certPEM1, keyPEM1)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert1, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(100), cert1.Leaf.SerialNumber.Int64())
|
||||
|
||||
// Write a new cert with a different serial.
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 200)
|
||||
writeCert(t, dir, certPEM2, keyPEM2)
|
||||
|
||||
// Manually trigger reload.
|
||||
w.tryReload()
|
||||
|
||||
cert2, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(200), cert2.Leaf.SerialNumber.Int64())
|
||||
}
|
||||
|
||||
func TestTryReloadSkipsUnchanged(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateSelfSignedCert(t, 42)
|
||||
writeCert(t, dir, certPEM, keyPEM)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert1, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reload with same cert - pointer should remain the same.
|
||||
w.tryReload()
|
||||
|
||||
cert2, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Same(t, cert1, cert2, "cert pointer should not change when content is the same")
|
||||
}
|
||||
|
||||
func TestWatchDetectsChanges(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM1, keyPEM1)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Use a short poll interval for the test.
|
||||
w.pollInterval = 100 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go w.Watch(ctx)
|
||||
|
||||
// Write new cert.
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 999)
|
||||
writeCert(t, dir, certPEM2, keyPEM2)
|
||||
|
||||
// Wait for the watcher to pick it up.
|
||||
require.Eventually(t, func() bool {
|
||||
cert, err := w.GetCertificate(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cert.Leaf.SerialNumber.Int64() == 999
|
||||
}, 5*time.Second, 50*time.Millisecond, "watcher should detect cert change")
|
||||
}
|
||||
|
||||
func TestIsRelevantFile(t *testing.T) {
|
||||
assert.True(t, isRelevantFile("tls.crt", "tls.crt", "tls.key"))
|
||||
assert.True(t, isRelevantFile("tls.key", "tls.crt", "tls.key"))
|
||||
assert.True(t, isRelevantFile("..data", "tls.crt", "tls.key"))
|
||||
assert.False(t, isRelevantFile("other.txt", "tls.crt", "tls.key"))
|
||||
}
|
||||
|
||||
// TestWatchSymlinkRotation simulates Kubernetes secret volume updates where
|
||||
// the data directory is atomically swapped via a ..data symlink.
|
||||
func TestWatchSymlinkRotation(t *testing.T) {
|
||||
base := t.TempDir()
|
||||
|
||||
// Create initial target directory with certs.
|
||||
dir1 := filepath.Join(base, "dir1")
|
||||
require.NoError(t, os.Mkdir(dir1, 0o755))
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.crt"), certPEM1, 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir1, "tls.key"), keyPEM1, 0o600))
|
||||
|
||||
// Create ..data symlink pointing to dir1.
|
||||
dataLink := filepath.Join(base, "..data")
|
||||
require.NoError(t, os.Symlink(dir1, dataLink))
|
||||
|
||||
// Create tls.crt and tls.key as symlinks to ..data/{file}.
|
||||
certLink := filepath.Join(base, "tls.crt")
|
||||
keyLink := filepath.Join(base, "tls.key")
|
||||
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.crt"), certLink))
|
||||
require.NoError(t, os.Symlink(filepath.Join(dataLink, "tls.key"), keyLink))
|
||||
|
||||
w, err := NewWatcher(certLink, keyLink, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
cert, err := w.GetCertificate(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), cert.Leaf.SerialNumber.Int64())
|
||||
|
||||
w.pollInterval = 100 * time.Millisecond
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go w.Watch(ctx)
|
||||
|
||||
// Simulate k8s atomic rotation: create dir2, swap ..data symlink.
|
||||
dir2 := filepath.Join(base, "dir2")
|
||||
require.NoError(t, os.Mkdir(dir2, 0o755))
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 777)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.crt"), certPEM2, 0o600))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(dir2, "tls.key"), keyPEM2, 0o600))
|
||||
|
||||
// Atomic swap: create temp link, then rename over ..data.
|
||||
tmpLink := filepath.Join(base, "..data_tmp")
|
||||
require.NoError(t, os.Symlink(dir2, tmpLink))
|
||||
require.NoError(t, os.Rename(tmpLink, dataLink))
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
cert, err := w.GetCertificate(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cert.Leaf.SerialNumber.Int64() == 777
|
||||
}, 5*time.Second, 50*time.Millisecond, "watcher should detect symlink rotation")
|
||||
}
|
||||
|
||||
// TestPollLoopDetectsChanges verifies the poll-only fallback path works.
|
||||
func TestPollLoopDetectsChanges(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM1, keyPEM1 := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM1, keyPEM1)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
w.pollInterval = 100 * time.Millisecond
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Directly use pollLoop to test the fallback path.
|
||||
go w.pollLoop(ctx)
|
||||
|
||||
certPEM2, keyPEM2 := generateSelfSignedCert(t, 555)
|
||||
writeCert(t, dir, certPEM2, keyPEM2)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
cert, err := w.GetCertificate(nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return cert.Leaf.SerialNumber.Int64() == 555
|
||||
}, 5*time.Second, 50*time.Millisecond, "poll loop should detect cert change")
|
||||
}
|
||||
|
||||
func TestGetCertificateConcurrency(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
certPEM, keyPEM := generateSelfSignedCert(t, 1)
|
||||
writeCert(t, dir, certPEM, keyPEM)
|
||||
|
||||
w, err := NewWatcher(
|
||||
filepath.Join(dir, "tls.crt"),
|
||||
filepath.Join(dir, "tls.key"),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Hammer GetCertificate concurrently while reloading.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
w.tryReload()
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
cert, err := w.GetCertificate(&tls.ClientHelloInfo{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, cert)
|
||||
}
|
||||
|
||||
<-done
|
||||
}
|
||||
20
proxy/internal/flock/flock_other.go
Normal file
20
proxy/internal/flock/flock_other.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build !unix
|
||||
|
||||
package flock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Lock is a no-op on non-Unix platforms. It returns (nil, nil) to signal
// that no lock was taken; callers must interpret a nil file as "proceed
// without locking", never as "another process holds the lock".
func Lock(_ context.Context, _ string) (*os.File, error) {
	return nil, nil
}

// Unlock is a no-op on non-Unix platforms; it accepts the nil file that
// the Lock stub above returns.
func Unlock(_ *os.File) error {
	return nil
}
|
||||
79
proxy/internal/flock/flock_test.go
Normal file
79
proxy/internal/flock/flock_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
//go:build unix
|
||||
|
||||
package flock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLockUnlock(t *testing.T) {
|
||||
lockPath := filepath.Join(t.TempDir(), "test.lock")
|
||||
|
||||
f, err := Lock(context.Background(), lockPath)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, f)
|
||||
|
||||
_, err = os.Stat(lockPath)
|
||||
assert.NoError(t, err, "lock file should exist")
|
||||
|
||||
err = Unlock(f)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestUnlockNil(t *testing.T) {
|
||||
err := Unlock(nil)
|
||||
assert.NoError(t, err, "unlocking nil should be a no-op")
|
||||
}
|
||||
|
||||
func TestLockRespectsContext(t *testing.T) {
|
||||
lockPath := filepath.Join(t.TempDir(), "test.lock")
|
||||
|
||||
f1, err := Lock(context.Background(), lockPath)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, Unlock(f1)) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err = Lock(ctx, lockPath)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestLockBlocks(t *testing.T) {
|
||||
lockPath := filepath.Join(t.TempDir(), "test.lock")
|
||||
|
||||
f1, err := Lock(context.Background(), lockPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
start := time.Now()
|
||||
var elapsed time.Duration
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
f2, err := Lock(context.Background(), lockPath)
|
||||
elapsed = time.Since(start)
|
||||
assert.NoError(t, err)
|
||||
if f2 != nil {
|
||||
assert.NoError(t, Unlock(f2))
|
||||
}
|
||||
}()
|
||||
|
||||
// Hold the lock for 200ms, then release.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
require.NoError(t, Unlock(f1))
|
||||
|
||||
wg.Wait()
|
||||
assert.GreaterOrEqual(t, elapsed, 150*time.Millisecond,
|
||||
"Lock should have blocked for at least ~200ms")
|
||||
}
|
||||
69
proxy/internal/flock/flock_unix.go
Normal file
69
proxy/internal/flock/flock_unix.go
Normal file
@@ -0,0 +1,69 @@
|
||||
//go:build unix
|
||||
|
||||
// Package flock provides best-effort advisory file locking using flock(2).
|
||||
//
|
||||
// This is used for cross-replica coordination (e.g. preventing duplicate
|
||||
// ACME requests). Note that flock(2) does NOT work reliably on NFS volumes:
|
||||
// on NFSv3 it depends on the NLM daemon, on NFSv4 Linux emulates it via
|
||||
// fcntl locks with different semantics. Callers must treat lock failures
|
||||
// as non-fatal and proceed without the lock.
|
||||
package flock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const retryInterval = 100 * time.Millisecond
|
||||
|
||||
// Lock acquires an exclusive advisory lock on the given file path.
|
||||
// It creates the lock file if it does not exist. The lock attempt
|
||||
// respects context cancellation by using non-blocking flock with polling.
|
||||
// The caller must call Unlock with the returned *os.File when done.
|
||||
func Lock(ctx context.Context, path string) (*os.File, error) {
|
||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o600)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open lock file %s: %w", path, err)
|
||||
}
|
||||
|
||||
for {
|
||||
if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil {
|
||||
return f, nil
|
||||
} else if !errors.Is(err, syscall.EWOULDBLOCK) {
|
||||
if cerr := f.Close(); cerr != nil {
|
||||
log.Debugf("close lock file %s: %v", path, cerr)
|
||||
}
|
||||
return nil, fmt.Errorf("acquire lock on %s: %w", path, err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if cerr := f.Close(); cerr != nil {
|
||||
log.Debugf("close lock file %s: %v", path, cerr)
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(retryInterval):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unlock drops the advisory lock held via f and closes the file. A nil f
// (e.g. from the non-Unix Lock stub) is a no-op.
func Unlock(f *os.File) error {
	if f == nil {
		return nil
	}

	// Close regardless of whether releasing succeeds; the flock result is
	// what callers care about, so the close error is intentionally dropped.
	defer f.Close()

	if err := syscall.Flock(int(f.Fd()), syscall.LOCK_UN); err != nil {
		return fmt.Errorf("release lock: %w", err)
	}

	return nil
}
|
||||
282
proxy/internal/k8s/lease.go
Normal file
282
proxy/internal/k8s/lease.go
Normal file
@@ -0,0 +1,282 @@
|
||||
// Package k8s provides a lightweight Kubernetes API client for coordination
|
||||
// Leases. It uses raw HTTP calls against the mounted service account
|
||||
// credentials, avoiding a dependency on client-go.
|
||||
package k8s
|
||||
|
||||
import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"time"
)
|
||||
|
||||
const (
	// Paths of the service account credentials the kubelet mounts into
	// every pod.
	saTokenPath     = "/var/run/secrets/kubernetes.io/serviceaccount/token"
	saNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
	saCACertPath    = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"

	// Base path of the coordination.k8s.io/v1 API group serving Leases.
	leaseAPIPath = "/apis/coordination.k8s.io/v1"
)

// ErrConflict is returned when a Lease write fails because another writer
// updated the object first (resourceVersion mismatch, HTTP 409).
var ErrConflict = errors.New("conflict: resource version mismatch")
|
||||
|
||||
// Lease represents a coordination.k8s.io/v1 Lease object with only the
|
||||
// fields needed for distributed locking.
|
||||
type Lease struct {
|
||||
APIVersion string `json:"apiVersion"`
|
||||
Kind string `json:"kind"`
|
||||
Metadata LeaseMetadata `json:"metadata"`
|
||||
Spec LeaseSpec `json:"spec"`
|
||||
}
|
||||
|
||||
// LeaseMetadata holds the standard k8s object metadata fields used by Leases.
|
||||
type LeaseMetadata struct {
|
||||
Name string `json:"name"`
|
||||
Namespace string `json:"namespace,omitempty"`
|
||||
ResourceVersion string `json:"resourceVersion,omitempty"`
|
||||
Annotations map[string]string `json:"annotations,omitempty"`
|
||||
}
|
||||
|
||||
// LeaseSpec holds the Lease specification fields.
|
||||
type LeaseSpec struct {
|
||||
HolderIdentity *string `json:"holderIdentity"`
|
||||
LeaseDurationSeconds *int32 `json:"leaseDurationSeconds,omitempty"`
|
||||
AcquireTime *MicroTime `json:"acquireTime"`
|
||||
RenewTime *MicroTime `json:"renewTime"`
|
||||
}
|
||||
|
||||
// MicroTime wraps time.Time with the Kubernetes MicroTime JSON encoding:
// RFC 3339 in UTC with exactly six fractional-second digits.
type MicroTime struct {
	time.Time
}

// microTimeFormat is the layout the API server emits for MicroTime values.
const microTimeFormat = "2006-01-02T15:04:05.000000Z"

// MarshalJSON implements json.Marshaler using the k8s MicroTime format.
func (t MicroTime) MarshalJSON() ([]byte, error) {
	return json.Marshal(t.UTC().Format(microTimeFormat))
}

// UnmarshalJSON implements json.Unmarshaler. It accepts the canonical k8s
// MicroTime format, falls back to plain RFC 3339 (tolerating non-UTC
// offsets and coarser fractional precision), and treats JSON null or an
// empty string as the zero time.
func (t *MicroTime) UnmarshalJSON(data []byte) error {
	// json only skips custom unmarshalers for null on pointer fields;
	// be defensive when the raw value itself is null.
	if string(data) == "null" {
		t.Time = time.Time{}
		return nil
	}

	var s string
	if err := json.Unmarshal(data, &s); err != nil {
		return err
	}
	if s == "" {
		t.Time = time.Time{}
		return nil
	}

	parsed, err := time.Parse(microTimeFormat, s)
	if err != nil {
		// Fallback: RFC 3339 covers timestamps with timezone offsets or
		// a different number of fractional digits.
		if rfc, rerr := time.Parse(time.RFC3339, s); rerr == nil {
			t.Time = rfc
			return nil
		}
		return fmt.Errorf("parse MicroTime %q: %w", s, err)
	}
	t.Time = parsed
	return nil
}
|
||||
|
||||
// LeaseClient talks to the Kubernetes coordination API using raw HTTP.
|
||||
type LeaseClient struct {
|
||||
baseURL string
|
||||
namespace string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewLeaseClient creates a client that authenticates via the pod's
|
||||
// mounted service account. It reads the namespace and CA certificate
|
||||
// at construction time (they don't rotate) but reads the bearer token
|
||||
// fresh on each request (tokens rotate).
|
||||
func NewLeaseClient() (*LeaseClient, error) {
|
||||
host := os.Getenv("KUBERNETES_SERVICE_HOST")
|
||||
port := os.Getenv("KUBERNETES_SERVICE_PORT")
|
||||
if host == "" || port == "" {
|
||||
return nil, fmt.Errorf("KUBERNETES_SERVICE_HOST/PORT not set")
|
||||
}
|
||||
|
||||
ns, err := os.ReadFile(saNamespacePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read namespace from %s: %w", saNamespacePath, err)
|
||||
}
|
||||
|
||||
caCert, err := os.ReadFile(saCACertPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read CA cert from %s: %w", saCACertPath, err)
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caCert) {
|
||||
return nil, fmt.Errorf("parse CA certificate from %s", saCACertPath)
|
||||
}
|
||||
|
||||
return &LeaseClient{
|
||||
baseURL: fmt.Sprintf("https://%s:%s", host, port),
|
||||
namespace: strings.TrimSpace(string(ns)),
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: pool,
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Namespace returns the namespace this client operates in.
|
||||
func (c *LeaseClient) Namespace() string {
|
||||
return c.namespace
|
||||
}
|
||||
|
||||
// Get retrieves a Lease by name. Returns (nil, nil) if the Lease does not exist.
|
||||
func (c *LeaseClient) Get(ctx context.Context, name string) (*Lease, error) {
|
||||
url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, name)
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, c.readError(resp)
|
||||
}
|
||||
|
||||
var lease Lease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&lease); err != nil {
|
||||
return nil, fmt.Errorf("decode lease response: %w", err)
|
||||
}
|
||||
return &lease, nil
|
||||
}
|
||||
|
||||
// Create creates a new Lease. Returns the created Lease with server-assigned
|
||||
// fields like resourceVersion populated.
|
||||
func (c *LeaseClient) Create(ctx context.Context, lease *Lease) (*Lease, error) {
|
||||
url := fmt.Sprintf("%s%s/namespaces/%s/leases", c.baseURL, leaseAPIPath, c.namespace)
|
||||
|
||||
lease.APIVersion = "coordination.k8s.io/v1"
|
||||
lease.Kind = "Lease"
|
||||
if lease.Metadata.Namespace == "" {
|
||||
lease.Metadata.Namespace = c.namespace
|
||||
}
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPost, url, lease)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusConflict {
|
||||
return nil, ErrConflict
|
||||
}
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return nil, c.readError(resp)
|
||||
}
|
||||
|
||||
var created Lease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&created); err != nil {
|
||||
return nil, fmt.Errorf("decode created lease: %w", err)
|
||||
}
|
||||
return &created, nil
|
||||
}
|
||||
|
||||
// Update replaces a Lease. The lease.Metadata.ResourceVersion must match
|
||||
// the current server value (optimistic concurrency). Returns ErrConflict
|
||||
// on version mismatch.
|
||||
func (c *LeaseClient) Update(ctx context.Context, lease *Lease) (*Lease, error) {
|
||||
url := fmt.Sprintf("%s%s/namespaces/%s/leases/%s", c.baseURL, leaseAPIPath, c.namespace, lease.Metadata.Name)
|
||||
|
||||
lease.APIVersion = "coordination.k8s.io/v1"
|
||||
lease.Kind = "Lease"
|
||||
|
||||
resp, err := c.doRequest(ctx, http.MethodPut, url, lease)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusConflict {
|
||||
return nil, ErrConflict
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, c.readError(resp)
|
||||
}
|
||||
|
||||
var updated Lease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&updated); err != nil {
|
||||
return nil, fmt.Errorf("decode updated lease: %w", err)
|
||||
}
|
||||
return &updated, nil
|
||||
}
|
||||
|
||||
func (c *LeaseClient) doRequest(ctx context.Context, method, url string, body any) (*http.Response, error) {
|
||||
token, err := readToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read service account token: %w", err)
|
||||
}
|
||||
|
||||
var bodyReader io.Reader
|
||||
if body != nil {
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request body: %w", err)
|
||||
}
|
||||
bodyReader = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, bodyReader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Accept", "application/json")
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
return c.httpClient.Do(req)
|
||||
}
|
||||
|
||||
func readToken() (string, error) {
|
||||
data, err := os.ReadFile(saTokenPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read %s: %w", saTokenPath, err)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
func (c *LeaseClient) readError(resp *http.Response) error {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("k8s API %s %d: %s", resp.Request.URL.Path, resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// LeaseNameForDomain derives a deterministic, DNS-label-safe Lease name
// for the given domain. Hashing avoids dots and keeps the name well under
// the 63-character label limit regardless of domain length.
func LeaseNameForDomain(domain string) string {
	sum := sha256.Sum256([]byte(domain))
	return "cert-lock-" + hex.EncodeToString(sum[:8])
}
|
||||
|
||||
// InCluster reports whether the process appears to be running inside a
// Kubernetes pod, detected via the KUBERNETES_SERVICE_HOST environment
// variable that the kubelet injects into every container.
func InCluster() bool {
	_, ok := os.LookupEnv("KUBERNETES_SERVICE_HOST")
	return ok
}
|
||||
102
proxy/internal/k8s/lease_test.go
Normal file
102
proxy/internal/k8s/lease_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package k8s
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLeaseNameForDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
}{
|
||||
{"example.com"},
|
||||
{"app.example.com"},
|
||||
{"another.domain.io"},
|
||||
}
|
||||
|
||||
seen := make(map[string]string)
|
||||
for _, tc := range tests {
|
||||
name := LeaseNameForDomain(tc.domain)
|
||||
|
||||
assert.True(t, len(name) <= 63, "must be valid DNS label length")
|
||||
assert.Regexp(t, `^cert-lock-[0-9a-f]{16}$`, name,
|
||||
"must match expected format for domain %q", tc.domain)
|
||||
|
||||
// Same input produces same output.
|
||||
assert.Equal(t, name, LeaseNameForDomain(tc.domain), "must be deterministic")
|
||||
|
||||
// Different domains produce different names.
|
||||
if prev, ok := seen[name]; ok {
|
||||
t.Errorf("collision: %q and %q both map to %s", prev, tc.domain, name)
|
||||
}
|
||||
seen[name] = tc.domain
|
||||
}
|
||||
}
|
||||
|
||||
func TestMicroTimeJSON(t *testing.T) {
|
||||
ts := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC)
|
||||
mt := &MicroTime{Time: ts}
|
||||
|
||||
data, err := json.Marshal(mt)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `"2024-06-15T10:30:00.000000Z"`, string(data))
|
||||
|
||||
var decoded MicroTime
|
||||
require.NoError(t, json.Unmarshal(data, &decoded))
|
||||
assert.True(t, ts.Equal(decoded.Time), "round-trip should preserve time")
|
||||
}
|
||||
|
||||
func TestMicroTimeNullJSON(t *testing.T) {
|
||||
// Null pointer serializes as JSON null via the Lease struct.
|
||||
spec := LeaseSpec{
|
||||
HolderIdentity: nil,
|
||||
AcquireTime: nil,
|
||||
RenewTime: nil,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(spec)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), `"acquireTime":null`)
|
||||
assert.Contains(t, string(data), `"renewTime":null`)
|
||||
}
|
||||
|
||||
func TestLeaseJSONRoundTrip(t *testing.T) {
|
||||
holder := "pod-abc"
|
||||
dur := int32(300)
|
||||
now := MicroTime{Time: time.Now().UTC().Truncate(time.Microsecond)}
|
||||
|
||||
original := Lease{
|
||||
APIVersion: "coordination.k8s.io/v1",
|
||||
Kind: "Lease",
|
||||
Metadata: LeaseMetadata{
|
||||
Name: "cert-lock-abcdef0123456789",
|
||||
Namespace: "default",
|
||||
ResourceVersion: "12345",
|
||||
Annotations: map[string]string{
|
||||
"netbird.io/domain": "app.example.com",
|
||||
},
|
||||
},
|
||||
Spec: LeaseSpec{
|
||||
HolderIdentity: &holder,
|
||||
LeaseDurationSeconds: &dur,
|
||||
AcquireTime: &now,
|
||||
RenewTime: &now,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded Lease
|
||||
require.NoError(t, json.Unmarshal(data, &decoded))
|
||||
|
||||
assert.Equal(t, original.Metadata.Name, decoded.Metadata.Name)
|
||||
assert.Equal(t, original.Metadata.ResourceVersion, decoded.Metadata.ResourceVersion)
|
||||
assert.Equal(t, *original.Spec.HolderIdentity, *decoded.Spec.HolderIdentity)
|
||||
assert.Equal(t, *original.Spec.LeaseDurationSeconds, *decoded.Spec.LeaseDurationSeconds)
|
||||
assert.True(t, original.Spec.AcquireTime.Equal(decoded.Spec.AcquireTime.Time))
|
||||
}
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
@@ -65,6 +66,8 @@ type Server struct {
|
||||
ProxyURL string
|
||||
ManagementAddress string
|
||||
CertificateDirectory string
|
||||
CertificateFile string
|
||||
CertificateKeyFile string
|
||||
GenerateACMECertificates bool
|
||||
ACMEChallengeAddress string
|
||||
ACMEDirectory string
|
||||
@@ -210,16 +213,17 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
"ServerName": s.ProxyURL,
|
||||
}).Debug("started ACME challenge server")
|
||||
} else {
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates")
|
||||
// Otherwise pull some certificates from expected locations.
|
||||
cert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(s.CertificateDirectory, "tls.crt"),
|
||||
filepath.Join(s.CertificateDirectory, "tls.key"),
|
||||
)
|
||||
s.Logger.Debug("ACME certificates disabled, using static certificates with file watching")
|
||||
certPath := filepath.Join(s.CertificateDirectory, s.CertificateFile)
|
||||
keyPath := filepath.Join(s.CertificateDirectory, s.CertificateKeyFile)
|
||||
|
||||
certWatcher, err := certwatch.NewWatcher(certPath, keyPath, s.Logger)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load provided certificate: %w", err)
|
||||
return fmt.Errorf("initialize certificate watcher: %w", err)
|
||||
}
|
||||
tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
|
||||
go certWatcher.Watch(ctx)
|
||||
|
||||
tlsConfig.GetCertificate = certWatcher.GetCertificate
|
||||
}
|
||||
|
||||
// Configure the reverse proxy using NetBird's HTTP Client Transport for proxying.
|
||||
|
||||
Reference in New Issue
Block a user