feat: add CLI command for encryption key rotation (#1209)

This commit is contained in:
Elias Schneider
2026-01-07 09:34:23 +01:00
committed by GitHub
parent 5828fa5779
commit 2af70d9b4d
13 changed files with 340 additions and 42 deletions

View File

@@ -0,0 +1,187 @@
package cmds
import (
"context"
"errors"
"fmt"
"os"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/spf13/cobra"
"gorm.io/gorm"
"github.com/pocket-id/pocket-id/backend/internal/bootstrap"
"github.com/pocket-id/pocket-id/backend/internal/common"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/service"
"github.com/pocket-id/pocket-id/backend/internal/utils"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
)
type encryptionKeyRotateFlags struct {
NewKey string
Yes bool
}
func init() {
var flags encryptionKeyRotateFlags
encryptionKeyRotateCmd := &cobra.Command{
Use: "encryption-key-rotate",
Short: "Re-encrypts data using a new encryption key",
RunE: func(cmd *cobra.Command, args []string) error {
db, err := bootstrap.NewDatabase()
if err != nil {
return err
}
return encryptionKeyRotate(cmd.Context(), flags, db, &common.EnvConfig)
},
}
encryptionKeyRotateCmd.Flags().StringVar(&flags.NewKey, "new-key", "", "New encryption key to re-encrypt data with")
encryptionKeyRotateCmd.Flags().BoolVarP(&flags.Yes, "yes", "y", false, "Do not prompt for confirmation")
rootCmd.AddCommand(encryptionKeyRotateCmd)
}
func encryptionKeyRotate(ctx context.Context, flags encryptionKeyRotateFlags, db *gorm.DB, envConfig *common.EnvConfigSchema) error {
oldKey := envConfig.EncryptionKey
newKey := []byte(flags.NewKey)
if len(newKey) == 0 {
return errors.New("new encryption key is required (--new-key)")
}
if len(newKey) < 16 {
return errors.New("new encryption key must be at least 16 bytes long")
}
if !flags.Yes {
fmt.Println("WARNING: Rotating the encryption key will re-encrypt secrets in the database. Pocket-ID must be restarted with the new ENCRYPTION_KEY after rotation is complete.")
ok, err := utils.PromptForConfirmation("Continue")
if err != nil {
return err
}
if !ok {
fmt.Println("Aborted")
os.Exit(1)
}
}
appConfigService, err := service.NewAppConfigService(ctx, db)
if err != nil {
return fmt.Errorf("failed to create app config service: %w", err)
}
instanceID := appConfigService.GetDbConfig().InstanceID.Value
// Derive the encryption keys used for the JWK encryption
oldKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: oldKey}, instanceID)
if err != nil {
return fmt.Errorf("failed to derive old key encryption key: %w", err)
}
newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID)
if err != nil {
return fmt.Errorf("failed to derive new key encryption key: %w", err)
}
// Derive the encryption keys used for EncryptedString fields
oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey)
if err != nil {
return fmt.Errorf("failed to derive old encrypted string key: %w", err)
}
newEncKey, err := datatype.DeriveEncryptedStringKey(newKey)
if err != nil {
return fmt.Errorf("failed to derive new encrypted string key: %w", err)
}
err = db.Transaction(func(tx *gorm.DB) error {
err = rotateSigningKeyEncryption(ctx, tx, oldKek, newKek)
if err != nil {
return err
}
err = rotateScimTokens(tx, oldEncKey, newEncKey)
if err != nil {
return err
}
return nil
})
if err != nil {
return err
}
fmt.Println("Encryption key rotation completed successfully.")
fmt.Println("Restart pocket-id with the new ENCRYPTION_KEY to use the rotated data.")
return nil
}
func rotateSigningKeyEncryption(ctx context.Context, db *gorm.DB, oldKek []byte, newKek []byte) error {
oldProvider := &jwkutils.KeyProviderDatabase{}
err := oldProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: oldKek,
})
if err != nil {
return fmt.Errorf("failed to init key provider with old encryption key: %w", err)
}
key, err := oldProvider.LoadKey(ctx)
if err != nil {
return fmt.Errorf("failed to load signing key using old encryption key: %w", err)
}
if key == nil {
return nil
}
newProvider := &jwkutils.KeyProviderDatabase{}
err = newProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: newKek,
})
if err != nil {
return fmt.Errorf("failed to init key provider with new encryption key: %w", err)
}
if err := newProvider.SaveKey(ctx, key); err != nil {
return fmt.Errorf("failed to store signing key with new encryption key: %w", err)
}
return nil
}
type scimTokenRow struct {
ID string
Token string
}
func rotateScimTokens(db *gorm.DB, oldEncKey []byte, newEncKey []byte) error {
var rows []scimTokenRow
err := db.Model(&model.ScimServiceProvider{}).Select("id, token").Scan(&rows).Error
if err != nil {
return fmt.Errorf("failed to list SCIM service providers: %w", err)
}
for _, row := range rows {
if row.Token == "" {
continue
}
decBytes, err := datatype.DecryptEncryptedStringWithKey(oldEncKey, row.Token)
if err != nil {
return fmt.Errorf("failed to decrypt SCIM token for provider %s: %w", row.ID, err)
}
encValue, err := datatype.EncryptEncryptedStringWithKey(newEncKey, decBytes)
if err != nil {
return fmt.Errorf("failed to encrypt SCIM token for provider %s: %w", row.ID, err)
}
err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", row.ID).Update("token", encValue).Error
if err != nil {
return fmt.Errorf("failed to update SCIM token for provider %s: %w", row.ID, err)
}
}
return nil
}

View File

@@ -0,0 +1,89 @@
package cmds
import (
"testing"
"time"
"github.com/pocket-id/pocket-id/backend/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pocket-id/pocket-id/backend/internal/common"
datatype "github.com/pocket-id/pocket-id/backend/internal/model/types"
"github.com/pocket-id/pocket-id/backend/internal/service"
jwkutils "github.com/pocket-id/pocket-id/backend/internal/utils/jwk"
testingutils "github.com/pocket-id/pocket-id/backend/internal/utils/testing"
)
func TestEncryptionKeyRotate(t *testing.T) {
oldKey := []byte("old-encryption-key-123456")
newKey := []byte("new-encryption-key-654321")
envConfig := &common.EnvConfigSchema{
EncryptionKey: oldKey,
}
db := testingutils.NewDatabaseForTest(t)
appConfigService, err := service.NewAppConfigService(t.Context(), db)
require.NoError(t, err)
instanceID := appConfigService.GetDbConfig().InstanceID.Value
oldKek, err := jwkutils.LoadKeyEncryptionKey(envConfig, instanceID)
require.NoError(t, err)
oldProvider := &jwkutils.KeyProviderDatabase{}
require.NoError(t, oldProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: oldKek,
}))
signingKey, err := jwkutils.GenerateKey("RS256", "")
require.NoError(t, err)
require.NoError(t, oldProvider.SaveKey(t.Context(), signingKey))
oldEncKey, err := datatype.DeriveEncryptedStringKey(oldKey)
require.NoError(t, err)
encToken, err := datatype.EncryptEncryptedStringWithKey(oldEncKey, []byte("scim-token-123"))
require.NoError(t, err)
err = db.Exec(
`INSERT INTO scim_service_providers (id, created_at, endpoint, token, oidc_client_id) VALUES (?, ?, ?, ?, ?)`,
"scim-1",
time.Now(),
"https://example.com/scim",
encToken,
"client-1",
).Error
require.NoError(t, err)
flags := encryptionKeyRotateFlags{
NewKey: string(newKey),
Yes: true,
}
require.NoError(t, encryptionKeyRotate(t.Context(), flags, db, envConfig))
newKek, err := jwkutils.LoadKeyEncryptionKey(&common.EnvConfigSchema{EncryptionKey: newKey}, instanceID)
require.NoError(t, err)
newProvider := &jwkutils.KeyProviderDatabase{}
require.NoError(t, newProvider.Init(jwkutils.KeyProviderOpts{
DB: db,
Kek: newKek,
}))
rotatedKey, err := newProvider.LoadKey(t.Context())
require.NoError(t, err)
require.NotNil(t, rotatedKey)
var storedToken string
err = db.Model(&model.ScimServiceProvider{}).Where("id = ?", "scim-1").Pluck("token", &storedToken).Error
require.NoError(t, err)
newEncKey, err := datatype.DeriveEncryptedStringKey(newKey)
require.NoError(t, err)
decBytes, err := datatype.DecryptEncryptedStringWithKey(newEncKey, storedToken)
require.NoError(t, err)
assert.Equal(t, "scim-token-123", string(decBytes))
}

View File

@@ -102,7 +102,7 @@ func keyRotate(ctx context.Context, flags keyRotateFlags, db *gorm.DB, envConfig
}
// Save the key
err = keyProvider.SaveKey(key)
err = keyProvider.SaveKey(ctx, key)
if err != nil {
return fmt.Errorf("failed to store new key: %w", err)
}

View File

@@ -104,7 +104,7 @@ func testKeyRotateWithDatabaseStorage(t *testing.T, flags keyRotateFlags, wantEr
require.NoError(t, err)
// Verify key was created
key, err := keyProvider.LoadKey()
key, err := keyProvider.LoadKey(t.Context())
require.NoError(t, err)
require.NotNil(t, key)