From 8f48d10d55269932c2ef880181d9586b57976362 Mon Sep 17 00:00:00 2001 From: Ingmar Stein <490610+IngmarStein@users.noreply.github.com> Date: Sun, 19 Apr 2026 14:04:22 +0200 Subject: [PATCH] feat: add TLS support for HTTP/2 server (#1429) Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com> --- backend/go.mod | 1 + backend/go.sum | 2 + .../internal/bootstrap/router_bootstrap.go | 158 +++++++++++++++++- backend/internal/common/env_config.go | 23 +++ backend/internal/common/env_config_test.go | 74 ++++++++ 5 files changed, 255 insertions(+), 3 deletions(-) diff --git a/backend/go.mod b/backend/go.mod index cf0733fa..640c0600 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -86,6 +86,7 @@ require ( github.com/disintegration/gift v1.2.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fsnotify/fsnotify v1.8.0 // indirect github.com/fxamacker/cbor/v2 v2.9.1 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect github.com/gin-contrib/sse v1.1.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index b2f3f386..3181bdd5 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -101,6 +101,8 @@ github.com/emersion/go-smtp v0.24.0 h1:g6AfoF140mvW0vLNPD/LuCBLEAdlxOjIXqbIkJIS6 github.com/emersion/go-smtp v0.24.0/go.mod h1:ZtRRkbTyp2XTHCA+BmyTFTrj8xY4I+b4McvHxCU2gsQ= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M= +github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ= github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM= diff --git a/backend/internal/bootstrap/router_bootstrap.go b/backend/internal/bootstrap/router_bootstrap.go index 213c7d0e..1a49ed51 100644 --- a/backend/internal/bootstrap/router_bootstrap.go +++ b/backend/internal/bootstrap/router_bootstrap.go @@ -2,6 +2,7 @@ package bootstrap import ( "context" + "crypto/tls" "errors" "fmt" "log/slog" @@ -10,8 +11,11 @@ import ( "os" "strconv" "strings" + "sync" + "sync/atomic" "time" + "github.com/fsnotify/fsnotify" sloggin "github.com/gin-contrib/slog" "github.com/gin-gonic/gin" "go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin" @@ -110,7 +114,29 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) { var protocols http.Protocols protocols.SetHTTP1(true) - protocols.SetUnencryptedHTTP2(true) + + var tlsConfig *tls.Config + var certProvider *tlsCertProvider + var certWatcher *fsnotify.Watcher + + if common.EnvConfig.TLSCertFile != "" && common.EnvConfig.TLSKeyFile != "" { + protocols.SetHTTP2(true) + + certProvider, err = newCertProvider(common.EnvConfig.TLSCertFile, common.EnvConfig.TLSKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load TLS certificate: %w", err) + } + + tlsConfig = &tls.Config{ + GetCertificate: certProvider.GetCertificate, + MinVersion: tls.VersionTLS13, + NextProtos: []string{"h2"}, + } + + slog.Info("TLS enabled") + } else { + protocols.SetUnencryptedHTTP2(true) + } // Set up the server srv := &http.Server{ @@ -158,14 +184,39 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) { // Service runner function runFn := func(ctx context.Context) error { - slog.Info("Server listening", slog.String("addr", addr)) + slog.Info("Server listening", slog.String("addr", addr), slog.Bool("tls", tlsConfig != nil)) + + // Set up certificate hot reloading if TLS is enabled + if certProvider != nil { + certWatcher, err = fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("failed to create certificate watcher: %w", err) + } + + // Watch both certificate and key files + if err := certWatcher.Add(common.EnvConfig.TLSCertFile); err != nil { + return fmt.Errorf("failed to watch TLS certificate: %w", err) + } + if err := certWatcher.Add(common.EnvConfig.TLSKeyFile); err != nil { + return fmt.Errorf("failed to watch TLS key: %w", err) + } + + // Start certificate watcher goroutine + go certProvider.StartWatching(ctx, certWatcher) + } // Start the server in a background goroutine go func() { defer listener.Close() // Next call blocks until the server is shut down - srvErr := srv.Serve(listener) + var srvErr error + if tlsConfig != nil { + srvErr = srv.Serve(tls.NewListener(listener, tlsConfig)) + } else { + srvErr = srv.Serve(listener) + } + if srvErr != http.ErrServerClosed { slog.Error("Error starting app server", "error", srvErr) os.Exit(1) @@ -192,6 +243,11 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) { slog.Warn("App server shutdown error", "error", shutdownErr) } + // Close certificate watcher + if certWatcher != nil { + certWatcher.Close() + } + return nil } @@ -224,3 +280,99 @@ func initLogger(r *gin.Engine) { }), )) } + +// tlsCertProvider holds certificates that can be dynamically reloaded +type tlsCertProvider struct { + certMutex sync.RWMutex + cert *tls.Certificate + certFile string + keyFile string + forceReload atomic.Bool +} + +// GetCertificate implements tls.GetCertificate interface for dynamic certificate loading +func (p *tlsCertProvider) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + if p.forceReload.Load() { + p.certMutex.Lock() + p.forceReload.Store(false) + p.certMutex.Unlock() + } + + p.certMutex.RLock() + defer p.certMutex.RUnlock() + return p.cert, nil +} + +// newCertProvider creates a new certificate provider with initial certificates loaded +func newCertProvider(certFile, keyFile string) (*tlsCertProvider, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + return &tlsCertProvider{ + cert: &cert, + certFile: certFile, + keyFile: keyFile, + }, nil +} + +// reloadCertificate reloads the certificate from disk +func (p *tlsCertProvider) reloadCertificate() error { + cert, err := tls.LoadX509KeyPair(p.certFile, p.keyFile) + if err != nil { + return fmt.Errorf("failed to reload TLS certificate: %w", err) + } + + p.certMutex.Lock() + p.cert = &cert + p.certMutex.Unlock() + + return nil +} + +// StartWatching begins monitoring the certificate files for changes with debouncing +func (p *tlsCertProvider) StartWatching(ctx context.Context, watcher *fsnotify.Watcher) { + debounceDuration := 1 * time.Second + reloadTimer := time.NewTimer(debounceDuration) + reloadTimer.Stop() + + for { + select { + case <-ctx.Done(): + return + case event, ok := <-watcher.Events: + if !ok { + return + } + // Only process write/rename events for certificate/key files + if event.Has(fsnotify.Write | fsnotify.Rename) { + // Reset the debounce timer whenever we get a relevant event + reloadTimer.Stop() + // Drain the channel if there's a pending value + select { + case <-reloadTimer.C: + default: + } + reloadTimer.Reset(debounceDuration) + slog.Debug("TLS file change detected, debouncing", slog.String("path", event.Name)) + } + case <-reloadTimer.C: + // Timer fired - no more events in 500ms, so reload + slog.Info("Reloading TLS certificate") + + if err := p.reloadCertificate(); err != nil { + slog.Error("Failed to reload TLS certificate", "error", err) + continue + } + + p.forceReload.Store(true) + slog.Info("TLS certificate reloaded successfully") + case err, ok := <-watcher.Errors: + if !ok { + return + } + slog.Error("Certificate watcher error", "error", err) + } + } +} diff --git a/backend/internal/common/env_config.go b/backend/internal/common/env_config.go index ed443275..c8c73ca3 100644 --- a/backend/internal/common/env_config.go +++ b/backend/internal/common/env_config.go @@ -70,6 +70,9 @@ type EnvConfigSchema struct { UnixSocketMode string `env:"UNIX_SOCKET_MODE"` LocalIPv6Ranges string `env:"LOCAL_IPV6_RANGES"` + TLSCertFile string `env:"TLS_CERT" options:"file"` + TLSKeyFile string `env:"TLS_KEY" options:"file"` + MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY" options:"file"` GeoLiteDBPath string `env:"GEOLITE_DB_PATH"` GeoLiteDBUrl string `env:"GEOLITE_DB_URL"` @@ -211,6 +214,26 @@ func ValidateEnvConfig(config *EnvConfigSchema) error { return errors.New("STATIC_API_KEY must be at least 16 characters long") } + // Validate TLS config + switch { + case config.TLSCertFile != "" && config.TLSKeyFile == "": + return errors.New("TLS_KEY_FILE must be set when TLS_CERT_FILE is set") + case config.TLSCertFile == "" && config.TLSKeyFile != "": + return errors.New("TLS_CERT_FILE must be set when TLS_KEY_FILE is set") + } + + if config.TLSCertFile != "" && config.TLSKeyFile != "" { + if _, err := os.Stat(config.TLSCertFile); err != nil { + return fmt.Errorf("TLS_CERT_FILE not found: %w", err) + } + } + + if config.TLSCertFile != "" && config.TLSKeyFile != "" { + if _, err := os.Stat(config.TLSKeyFile); err != nil { + return fmt.Errorf("TLS_KEY_FILE not found: %w", err) + } + } + return nil } diff --git a/backend/internal/common/env_config_test.go b/backend/internal/common/env_config_test.go index 22044a60..b209b441 100644 --- a/backend/internal/common/env_config_test.go +++ b/backend/internal/common/env_config_test.go @@ -207,6 +207,58 @@ func TestParseEnvConfig(t *testing.T) { require.Error(t, err) assert.ErrorContains(t, err, "invalid FILE_BACKEND value") }) + + t.Run("should fail when TLS cert is set without key", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("TLS_CERT", "/path/to/cert.pem") + + err := parseAndValidateEnvConfig(t) + require.Error(t, err) + assert.ErrorContains(t, err, "TLS_KEY_FILE must be set when TLS_CERT_FILE is set") + }) + + t.Run("should fail when TLS key is set without cert", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("TLS_KEY", "/path/to/key.pem") + + err := parseAndValidateEnvConfig(t) + require.Error(t, err) + assert.ErrorContains(t, err, "TLS_CERT_FILE must be set when TLS_KEY_FILE is set") + }) + + t.Run("should fail when TLS cert file does not exist", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + t.Setenv("TLS_CERT", "/nonexistent/cert.pem") + + keyFile := t.TempDir() + "/key.pem" + require.NoError(t, os.WriteFile(keyFile, []byte("key"), 0600)) + t.Setenv("TLS_KEY", keyFile) + + err := parseAndValidateEnvConfig(t) + require.Error(t, err) + assert.ErrorContains(t, err, "TLS_CERT_FILE not found") + }) + + t.Run("should fail when TLS key file does not exist", func(t *testing.T) { + EnvConfig = defaultConfig() + t.Setenv("DB_CONNECTION_STRING", "file:test.db") + t.Setenv("APP_URL", "http://localhost:3000") + + certFile := t.TempDir() + "/cert.pem" + require.NoError(t, os.WriteFile(certFile, []byte("cert"), 0600)) + t.Setenv("TLS_CERT", certFile) + t.Setenv("TLS_KEY", "/nonexistent/key.pem") + + err := parseAndValidateEnvConfig(t) + require.Error(t, err) + assert.ErrorContains(t, err, "TLS_KEY_FILE not found") + }) } func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) { @@ -254,4 +306,26 @@ func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) { require.NoError(t, err) assert.Equal(t, binaryKeyContent, config.EncryptionKey) }) + + t.Run("should load TLS cert and key file contents", func(t *testing.T) { + config := defaultConfig() + + certFile := tempDir + "/cert.pem" + certContent := "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----" + err := os.WriteFile(certFile, []byte(certContent), 0600) + require.NoError(t, err) + + keyFile := tempDir + "/key.pem" + keyContent := "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----" + err = os.WriteFile(keyFile, []byte(keyContent), 0600) + require.NoError(t, err) + + t.Setenv("TLS_CERT_FILE", certFile) + t.Setenv("TLS_KEY_FILE", keyFile) + + err = prepareEnvConfig(&config) + require.NoError(t, err) + assert.Equal(t, certContent, config.TLSCertFile) + assert.Equal(t, keyContent, config.TLSKeyFile) + }) }