mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-03-27 17:56:36 +00:00
feat: add CLI command for encryption key rotation (#1209)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwk"
|
||||
@@ -17,8 +18,8 @@ type KeyProviderOpts struct {
|
||||
|
||||
type KeyProvider interface {
|
||||
Init(opts KeyProviderOpts) error
|
||||
LoadKey() (jwk.Key, error)
|
||||
SaveKey(key jwk.Key) error
|
||||
LoadKey(ctx context.Context) (jwk.Key, error)
|
||||
SaveKey(ctx context.Context, key jwk.Key) error
|
||||
}
|
||||
|
||||
func GetKeyProvider(db *gorm.DB, envConfig *common.EnvConfigSchema, instanceID string) (keyProvider KeyProvider, err error) {
|
||||
|
||||
@@ -33,12 +33,12 @@ func (f *KeyProviderDatabase) Init(opts KeyProviderOpts) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
|
||||
func (f *KeyProviderDatabase) LoadKey(ctx context.Context) (key jwk.Key, err error) {
|
||||
row := model.KV{
|
||||
Key: PrivateKeyDBKey,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.WithContext(ctx).First(&row).Error
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
@@ -74,7 +74,7 @@ func (f *KeyProviderDatabase) LoadKey() (key jwk.Key, err error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
||||
func (f *KeyProviderDatabase) SaveKey(ctx context.Context, key jwk.Key) error {
|
||||
// Encode the key to JSON
|
||||
data, err := EncodeJWKBytes(key)
|
||||
if err != nil {
|
||||
@@ -94,7 +94,7 @@ func (f *KeyProviderDatabase) SaveKey(key jwk.Key) error {
|
||||
Value: &encB64,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
err = f.db.
|
||||
WithContext(ctx).
|
||||
|
||||
@@ -59,7 +59,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load key when none exists
|
||||
loadedKey, err := provider.LoadKey()
|
||||
loadedKey, err := provider.LoadKey(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, loadedKey, "Expected nil key when no key exists in database")
|
||||
})
|
||||
@@ -76,11 +76,11 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save a key
|
||||
err = provider.SaveKey(key)
|
||||
err = provider.SaveKey(t.Context(), key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
loadedKey, err := provider.LoadKey(t.Context())
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, loadedKey, "Expected non-nil key when key exists in database")
|
||||
|
||||
@@ -114,7 +114,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
loadedKey, err := provider.LoadKey(t.Context())
|
||||
require.Error(t, err, "Expected error when loading key with invalid base64")
|
||||
require.ErrorContains(t, err, "not a valid base64-encoded value")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
@@ -140,7 +140,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
loadedKey, err := provider.LoadKey(t.Context())
|
||||
require.Error(t, err, "Expected error when loading key with invalid encrypted data")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
@@ -158,7 +158,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = originalProvider.SaveKey(key)
|
||||
err = originalProvider.SaveKey(t.Context(), key)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now try to load with a different KEK
|
||||
@@ -171,7 +171,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key with the wrong KEK
|
||||
loadedKey, err := differentProvider.LoadKey()
|
||||
loadedKey, err := differentProvider.LoadKey(t.Context())
|
||||
require.Error(t, err, "Expected error when loading key with wrong KEK")
|
||||
require.ErrorContains(t, err, "failed to decrypt")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
@@ -206,7 +206,7 @@ func TestKeyProviderDatabase_LoadKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to load the key
|
||||
loadedKey, err := provider.LoadKey()
|
||||
loadedKey, err := provider.LoadKey(t.Context())
|
||||
require.Error(t, err, "Expected error when loading invalid key data")
|
||||
require.ErrorContains(t, err, "failed to parse")
|
||||
assert.Nil(t, loadedKey, "Expected nil key when loading fails")
|
||||
@@ -233,7 +233,7 @@ func TestKeyProviderDatabase_SaveKey(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Save the key
|
||||
err = provider.SaveKey(key)
|
||||
err = provider.SaveKey(t.Context(), key)
|
||||
require.NoError(t, err, "Expected no error when saving key")
|
||||
|
||||
// Verify record exists in database
|
||||
|
||||
Reference in New Issue
Block a user