mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-05-14 17:09:53 +00:00
feat: add TLS support for HTTP/2 server (#1429)
Co-authored-by: Alessandro (Ale) Segala <43508+ItalyPaleAle@users.noreply.github.com>
This commit is contained in:
@@ -86,6 +86,7 @@ require (
|
||||
github.com/disintegration/gift v1.2.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/fsnotify/fsnotify v1.8.0 // indirect
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 // indirect
|
||||
github.com/gin-contrib/sse v1.1.1 // indirect
|
||||
|
||||
@@ -101,6 +101,8 @@ github.com/emersion/go-smtp v0.24.0 h1:g6AfoF140mvW0vLNPD/LuCBLEAdlxOjIXqbIkJIS6
|
||||
github.com/emersion/go-smtp v0.24.0/go.mod h1:ZtRRkbTyp2XTHCA+BmyTFTrj8xY4I+b4McvHxCU2gsQ=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
|
||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1 h1:2rWm8B193Ll4VdjsJY28jxs70IdDsHRWgQYAI80+rMQ=
|
||||
github.com/fxamacker/cbor/v2 v2.9.1/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ=
|
||||
github.com/gabriel-vasile/mimetype v1.4.13 h1:46nXokslUBsAJE/wMsp5gtO500a4F3Nkz9Ufpk2AcUM=
|
||||
|
||||
@@ -2,6 +2,7 @@ package bootstrap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -10,8 +11,11 @@ import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
sloggin "github.com/gin-contrib/slog"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
|
||||
@@ -110,7 +114,29 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
|
||||
var protocols http.Protocols
|
||||
protocols.SetHTTP1(true)
|
||||
protocols.SetUnencryptedHTTP2(true)
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
var certProvider *tlsCertProvider
|
||||
var certWatcher *fsnotify.Watcher
|
||||
|
||||
if common.EnvConfig.TLSCertFile != "" && common.EnvConfig.TLSKeyFile != "" {
|
||||
protocols.SetHTTP2(true)
|
||||
|
||||
certProvider, err = newCertProvider(common.EnvConfig.TLSCertFile, common.EnvConfig.TLSKeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
GetCertificate: certProvider.GetCertificate,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
NextProtos: []string{"h2"},
|
||||
}
|
||||
|
||||
slog.Info("TLS enabled")
|
||||
} else {
|
||||
protocols.SetUnencryptedHTTP2(true)
|
||||
}
|
||||
|
||||
// Set up the server
|
||||
srv := &http.Server{
|
||||
@@ -158,14 +184,39 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
|
||||
// Service runner function
|
||||
runFn := func(ctx context.Context) error {
|
||||
slog.Info("Server listening", slog.String("addr", addr))
|
||||
slog.Info("Server listening", slog.String("addr", addr), slog.Bool("tls", tlsConfig != nil))
|
||||
|
||||
// Set up certificate hot reloading if TLS is enabled
|
||||
if certProvider != nil {
|
||||
certWatcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create certificate watcher: %w", err)
|
||||
}
|
||||
|
||||
// Watch both certificate and key files
|
||||
if err := certWatcher.Add(common.EnvConfig.TLSCertFile); err != nil {
|
||||
return fmt.Errorf("failed to watch TLS certificate: %w", err)
|
||||
}
|
||||
if err := certWatcher.Add(common.EnvConfig.TLSKeyFile); err != nil {
|
||||
return fmt.Errorf("failed to watch TLS key: %w", err)
|
||||
}
|
||||
|
||||
// Start certificate watcher goroutine
|
||||
go certProvider.StartWatching(ctx, certWatcher)
|
||||
}
|
||||
|
||||
// Start the server in a background goroutine
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
|
||||
// Next call blocks until the server is shut down
|
||||
srvErr := srv.Serve(listener)
|
||||
var srvErr error
|
||||
if tlsConfig != nil {
|
||||
srvErr = srv.Serve(tls.NewListener(listener, tlsConfig))
|
||||
} else {
|
||||
srvErr = srv.Serve(listener)
|
||||
}
|
||||
|
||||
if srvErr != http.ErrServerClosed {
|
||||
slog.Error("Error starting app server", "error", srvErr)
|
||||
os.Exit(1)
|
||||
@@ -192,6 +243,11 @@ func initRouter(db *gorm.DB, svc *services) (utils.Service, error) {
|
||||
slog.Warn("App server shutdown error", "error", shutdownErr)
|
||||
}
|
||||
|
||||
// Close certificate watcher
|
||||
if certWatcher != nil {
|
||||
certWatcher.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -224,3 +280,99 @@ func initLogger(r *gin.Engine) {
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
// tlsCertProvider holds certificates that can be dynamically reloaded
|
||||
type tlsCertProvider struct {
|
||||
certMutex sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
certFile string
|
||||
keyFile string
|
||||
forceReload atomic.Bool
|
||||
}
|
||||
|
||||
// GetCertificate implements tls.GetCertificate interface for dynamic certificate loading
|
||||
func (p *tlsCertProvider) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if p.forceReload.Load() {
|
||||
p.certMutex.Lock()
|
||||
p.forceReload.Store(false)
|
||||
p.certMutex.Unlock()
|
||||
}
|
||||
|
||||
p.certMutex.RLock()
|
||||
defer p.certMutex.RUnlock()
|
||||
return p.cert, nil
|
||||
}
|
||||
|
||||
// newCertProvider creates a new certificate provider with initial certificates loaded
|
||||
func newCertProvider(certFile, keyFile string) (*tlsCertProvider, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tlsCertProvider{
|
||||
cert: &cert,
|
||||
certFile: certFile,
|
||||
keyFile: keyFile,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// reloadCertificate reloads the certificate from disk
|
||||
func (p *tlsCertProvider) reloadCertificate() error {
|
||||
cert, err := tls.LoadX509KeyPair(p.certFile, p.keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to reload TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
p.certMutex.Lock()
|
||||
p.cert = &cert
|
||||
p.certMutex.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartWatching begins monitoring the certificate files for changes with debouncing
|
||||
func (p *tlsCertProvider) StartWatching(ctx context.Context, watcher *fsnotify.Watcher) {
|
||||
debounceDuration := 1 * time.Second
|
||||
reloadTimer := time.NewTimer(debounceDuration)
|
||||
reloadTimer.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event, ok := <-watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Only process write/rename events for certificate/key files
|
||||
if event.Has(fsnotify.Write | fsnotify.Rename) {
|
||||
// Reset the debounce timer whenever we get a relevant event
|
||||
reloadTimer.Stop()
|
||||
// Drain the channel if there's a pending value
|
||||
select {
|
||||
case <-reloadTimer.C:
|
||||
default:
|
||||
}
|
||||
reloadTimer.Reset(debounceDuration)
|
||||
slog.Debug("TLS file change detected, debouncing", slog.String("path", event.Name))
|
||||
}
|
||||
case <-reloadTimer.C:
|
||||
// Timer fired - no more events in 500ms, so reload
|
||||
slog.Info("Reloading TLS certificate")
|
||||
|
||||
if err := p.reloadCertificate(); err != nil {
|
||||
slog.Error("Failed to reload TLS certificate", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
p.forceReload.Store(true)
|
||||
slog.Info("TLS certificate reloaded successfully")
|
||||
case err, ok := <-watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
slog.Error("Certificate watcher error", "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +70,9 @@ type EnvConfigSchema struct {
|
||||
UnixSocketMode string `env:"UNIX_SOCKET_MODE"`
|
||||
LocalIPv6Ranges string `env:"LOCAL_IPV6_RANGES"`
|
||||
|
||||
TLSCertFile string `env:"TLS_CERT" options:"file"`
|
||||
TLSKeyFile string `env:"TLS_KEY" options:"file"`
|
||||
|
||||
MaxMindLicenseKey string `env:"MAXMIND_LICENSE_KEY" options:"file"`
|
||||
GeoLiteDBPath string `env:"GEOLITE_DB_PATH"`
|
||||
GeoLiteDBUrl string `env:"GEOLITE_DB_URL"`
|
||||
@@ -211,6 +214,26 @@ func ValidateEnvConfig(config *EnvConfigSchema) error {
|
||||
return errors.New("STATIC_API_KEY must be at least 16 characters long")
|
||||
}
|
||||
|
||||
// Validate TLS config
|
||||
switch {
|
||||
case config.TLSCertFile != "" && config.TLSKeyFile == "":
|
||||
return errors.New("TLS_KEY_FILE must be set when TLS_CERT_FILE is set")
|
||||
case config.TLSCertFile == "" && config.TLSKeyFile != "":
|
||||
return errors.New("TLS_CERT_FILE must be set when TLS_KEY_FILE is set")
|
||||
}
|
||||
|
||||
if config.TLSCertFile != "" && config.TLSKeyFile != "" {
|
||||
if _, err := os.Stat(config.TLSCertFile); err != nil {
|
||||
return fmt.Errorf("TLS_CERT_FILE not found: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if config.TLSCertFile != "" && config.TLSKeyFile != "" {
|
||||
if _, err := os.Stat(config.TLSKeyFile); err != nil {
|
||||
return fmt.Errorf("TLS_KEY_FILE not found: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
@@ -207,6 +207,58 @@ func TestParseEnvConfig(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "invalid FILE_BACKEND value")
|
||||
})
|
||||
|
||||
t.Run("should fail when TLS cert is set without key", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("TLS_CERT", "/path/to/cert.pem")
|
||||
|
||||
err := parseAndValidateEnvConfig(t)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "TLS_KEY_FILE must be set when TLS_CERT_FILE is set")
|
||||
})
|
||||
|
||||
t.Run("should fail when TLS key is set without cert", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("TLS_KEY", "/path/to/key.pem")
|
||||
|
||||
err := parseAndValidateEnvConfig(t)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "TLS_CERT_FILE must be set when TLS_KEY_FILE is set")
|
||||
})
|
||||
|
||||
t.Run("should fail when TLS cert file does not exist", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
t.Setenv("TLS_CERT", "/nonexistent/cert.pem")
|
||||
|
||||
keyFile := t.TempDir() + "/key.pem"
|
||||
require.NoError(t, os.WriteFile(keyFile, []byte("key"), 0600))
|
||||
t.Setenv("TLS_KEY", keyFile)
|
||||
|
||||
err := parseAndValidateEnvConfig(t)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "TLS_CERT_FILE not found")
|
||||
})
|
||||
|
||||
t.Run("should fail when TLS key file does not exist", func(t *testing.T) {
|
||||
EnvConfig = defaultConfig()
|
||||
t.Setenv("DB_CONNECTION_STRING", "file:test.db")
|
||||
t.Setenv("APP_URL", "http://localhost:3000")
|
||||
|
||||
certFile := t.TempDir() + "/cert.pem"
|
||||
require.NoError(t, os.WriteFile(certFile, []byte("cert"), 0600))
|
||||
t.Setenv("TLS_CERT", certFile)
|
||||
t.Setenv("TLS_KEY", "/nonexistent/key.pem")
|
||||
|
||||
err := parseAndValidateEnvConfig(t)
|
||||
require.Error(t, err)
|
||||
assert.ErrorContains(t, err, "TLS_KEY_FILE not found")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) {
|
||||
@@ -254,4 +306,26 @@ func TestPrepareEnvConfig_FileBasedAndToLower(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, binaryKeyContent, config.EncryptionKey)
|
||||
})
|
||||
|
||||
t.Run("should load TLS cert and key file contents", func(t *testing.T) {
|
||||
config := defaultConfig()
|
||||
|
||||
certFile := tempDir + "/cert.pem"
|
||||
certContent := "-----BEGIN CERTIFICATE-----\ntest\n-----END CERTIFICATE-----"
|
||||
err := os.WriteFile(certFile, []byte(certContent), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
keyFile := tempDir + "/key.pem"
|
||||
keyContent := "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----"
|
||||
err = os.WriteFile(keyFile, []byte(keyContent), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Setenv("TLS_CERT_FILE", certFile)
|
||||
t.Setenv("TLS_KEY_FILE", keyFile)
|
||||
|
||||
err = prepareEnvConfig(&config)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, certContent, config.TLSCertFile)
|
||||
assert.Equal(t, keyContent, config.TLSKeyFile)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user