mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management] Add postgres support for activity event store (#3890)
This commit is contained in:
136
management/server/activity/store/crypt.go
Normal file
136
management/server/activity/store/crypt.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||
|
||||
type FieldEncrypt struct {
|
||||
block cipher.Block
|
||||
gcm cipher.AEAD
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
readableKey := base64.StdEncoding.EncodeToString(key)
|
||||
return readableKey, nil
|
||||
}
|
||||
|
||||
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
|
||||
binKey, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ec := &FieldEncrypt{
|
||||
block: block,
|
||||
gcm: gcm,
|
||||
}
|
||||
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
|
||||
plainText := pkcs5Padding([]byte(payload))
|
||||
cipherText := make([]byte, len(plainText))
|
||||
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
||||
cbc.CryptBlocks(cipherText, plainText)
|
||||
return base64.StdEncoding.EncodeToString(cipherText)
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-GCM
|
||||
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
|
||||
plaintext := []byte(payload)
|
||||
nonceSize := ec.gcm.NonceSize()
|
||||
|
||||
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cbc := cipher.NewCBCDecrypter(ec.block, iv)
|
||||
cbc.CryptBlocks(cipherText, cipherText)
|
||||
payload, err := pkcs5UnPadding(cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(payload), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-GCM
|
||||
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := ec.gcm.NonceSize()
|
||||
if len(cipherText) < nonceSize {
|
||||
return "", errors.New("cipher text too short")
|
||||
}
|
||||
|
||||
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
|
||||
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plainText), nil
|
||||
}
|
||||
|
||||
func pkcs5Padding(ciphertext []byte) []byte {
|
||||
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padText...)
|
||||
}
|
||||
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
||||
srcLen := len(src)
|
||||
if srcLen == 0 {
|
||||
return nil, errors.New("input data is empty")
|
||||
}
|
||||
|
||||
paddingLen := int(src[srcLen-1])
|
||||
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
|
||||
return nil, errors.New("invalid padding size")
|
||||
}
|
||||
|
||||
// Verify that all padding bytes are the same
|
||||
for i := 0; i < paddingLen; i++ {
|
||||
if src[srcLen-1-i] != byte(paddingLen) {
|
||||
return nil, errors.New("invalid padding")
|
||||
}
|
||||
}
|
||||
|
||||
return src[:srcLen-paddingLen], nil
|
||||
}
|
||||
310
management/server/activity/store/crypt_test.go
Normal file
310
management/server/activity/store/crypt_test.go
Normal file
@@ -0,0 +1,310 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted, err := ee.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt data: %s", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
decrypted, err := ee.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decrypt data: %s", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateKeyLegacy(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.LegacyEncrypt(testData)
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
decrypted, err := ee.LegacyDecrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decrypt data: %s", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorruptKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted, err := ee.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to encrypt data: %s", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
newKey, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
|
||||
ee, err = NewFieldEncrypt(newKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
res, _ := ee.Decrypt(encrypted)
|
||||
if res == testData {
|
||||
t.Fatalf("incorrect decryption, the result is: %s", res)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Generate a key for encryption/decryption
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate key: %v", err)
|
||||
}
|
||||
|
||||
// Initialize the FieldEncrypt with the generated key
|
||||
ec, err := NewFieldEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create FieldEncrypt: %v", err)
|
||||
}
|
||||
|
||||
// Test cases
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Empty String",
|
||||
input: "",
|
||||
},
|
||||
{
|
||||
name: "Short String",
|
||||
input: "Hello",
|
||||
},
|
||||
{
|
||||
name: "String with Spaces",
|
||||
input: "Hello, World!",
|
||||
},
|
||||
{
|
||||
name: "Long String",
|
||||
input: "The quick brown fox jumps over the lazy dog.",
|
||||
},
|
||||
{
|
||||
name: "Unicode Characters",
|
||||
input: "こんにちは世界",
|
||||
},
|
||||
{
|
||||
name: "Special Characters",
|
||||
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||||
},
|
||||
{
|
||||
name: "Numeric String",
|
||||
input: "1234567890",
|
||||
},
|
||||
{
|
||||
name: "Repeated Characters",
|
||||
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
},
|
||||
{
|
||||
name: "Multi-block String",
|
||||
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
|
||||
},
|
||||
{
|
||||
name: "Non-ASCII and ASCII Mix",
|
||||
input: "Hello 世界 123",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name+" - Legacy", func(t *testing.T) {
|
||||
// Legacy Encryption
|
||||
encryptedLegacy := ec.LegacyEncrypt(tc.input)
|
||||
if encryptedLegacy == "" {
|
||||
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
|
||||
}
|
||||
|
||||
// Legacy Decryption
|
||||
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
|
||||
if err != nil {
|
||||
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
|
||||
// Verify that the decrypted value matches the original input
|
||||
if decryptedLegacy != tc.input {
|
||||
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(tc.name+" - New", func(t *testing.T) {
|
||||
// New Encryption
|
||||
encryptedNew, err := ec.Encrypt(tc.input)
|
||||
if err != nil {
|
||||
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
if encryptedNew == "" {
|
||||
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
|
||||
}
|
||||
|
||||
// New Decryption
|
||||
decryptedNew, err := ec.Decrypt(encryptedNew)
|
||||
if err != nil {
|
||||
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
|
||||
}
|
||||
|
||||
// Verify that the decrypted value matches the original input
|
||||
if decryptedNew != tc.input {
|
||||
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCS5UnPadding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected []byte
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Padding",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Empty Input",
|
||||
input: []byte{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Zero",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Exceeds Block Size",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Exceeds Input Length",
|
||||
input: []byte{5, 5, 5},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid Padding Bytes",
|
||||
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Single Byte Padding",
|
||||
input: append([]byte("Hello, World!"), byte(1)),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Invalid Mixed Padding Bytes",
|
||||
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Full Block Padding",
|
||||
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
|
||||
expected: []byte("Hello, World!"),
|
||||
},
|
||||
{
|
||||
name: "Non-Padding Byte at End",
|
||||
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Valid Padding with Different Text Length",
|
||||
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
|
||||
expected: []byte("Test"),
|
||||
},
|
||||
{
|
||||
name: "Padding Length Equal to Input Length",
|
||||
input: bytes.Repeat([]byte{8}, 8),
|
||||
expected: []byte{},
|
||||
},
|
||||
{
|
||||
name: "Invalid Padding Length Zero (Again)",
|
||||
input: append([]byte("Test"), byte(0)),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Padding Length Greater Than Input",
|
||||
input: []byte{10},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Input Length Not Multiple of Block Size",
|
||||
input: append([]byte("Invalid Length"), byte(1)),
|
||||
expected: []byte("Invalid Length"),
|
||||
},
|
||||
{
|
||||
name: "Valid Padding with Non-ASCII Characters",
|
||||
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
|
||||
expected: []byte("こんにちは"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := pkcs5UnPadding(tt.input)
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Did not expect error but got: %v", err)
|
||||
}
|
||||
if !bytes.Equal(result, tt.expected) {
|
||||
t.Errorf("Expected output %v, got %v", tt.expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
185
management/server/activity/store/migration.go
Normal file
185
management/server/activity/store/migration.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
)
|
||||
|
||||
func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
|
||||
migrations := getMigrations(ctx, crypt)
|
||||
|
||||
for _, m := range migrations {
|
||||
if err := m(db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type migrationFunc func(*gorm.DB) error
|
||||
|
||||
func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
|
||||
return []migrationFunc{
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "enc_algo", "")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migrateLegacyEncryptedUsersToGCM(ctx, db, crypt)
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migrateDuplicateDeletedUsers(ctx, db)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using
|
||||
// legacy CBC encryption with a static IV to the new GCM encryption method.
|
||||
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error {
|
||||
model := &activity.DeletedUser{}
|
||||
|
||||
if !db.Migrator().HasTable(model) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no CBC to GCM migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
var deletedUsers []activity.DeletedUser
|
||||
err := db.Model(model).Find(&deletedUsers, "enc_algo IS NULL OR enc_algo != ?", gcmEncAlgo).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query deleted_users: %w", err)
|
||||
}
|
||||
|
||||
if len(deletedUsers) == 0 {
|
||||
log.WithContext(ctx).Debug("No CBC encrypted deleted users to migrate")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, user := range deletedUsers {
|
||||
if err = updateDeletedUserData(tx, user, crypt); err != nil {
|
||||
return fmt.Errorf("failed to migrate deleted user %s: %w", user.ID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error {
|
||||
var err error
|
||||
var decryptedEmail, decryptedName string
|
||||
|
||||
if user.Email != "" {
|
||||
decryptedEmail, err = crypt.LegacyDecrypt(user.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if user.Name != "" {
|
||||
decryptedName, err = crypt.LegacyDecrypt(user.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
updatedUser := user
|
||||
updatedUser.EncAlgo = gcmEncAlgo
|
||||
|
||||
updatedUser.Email, err = crypt.Encrypt(decryptedEmail)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt email: %w", err)
|
||||
}
|
||||
|
||||
updatedUser.Name, err = crypt.Encrypt(decryptedName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt name: %w", err)
|
||||
}
|
||||
|
||||
return transaction.Model(&updatedUser).Omit("id").Updates(updatedUser).Error
|
||||
}
|
||||
|
||||
// MigrateDuplicateDeletedUsers removes duplicates and ensures the id column is marked as the primary key
|
||||
func migrateDuplicateDeletedUsers(ctx context.Context, db *gorm.DB) error {
|
||||
model := &activity.DeletedUser{}
|
||||
if !db.Migrator().HasTable(model) {
|
||||
log.WithContext(ctx).Debugf("Table for %T does not exist, no duplicate migration needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
isPrimaryKey, err := isColumnPrimaryKey[activity.DeletedUser](db, "id")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if isPrimaryKey {
|
||||
log.WithContext(ctx).Debug("No duplicate deleted users to migrate")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = db.Transaction(func(tx *gorm.DB) error {
|
||||
if err = tx.Migrator().RenameTable("deleted_users", "deleted_users_old"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.Migrator().CreateTable(model); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var deletedUsers []activity.DeletedUser
|
||||
if err = tx.Table("deleted_users_old").Find(&deletedUsers).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, deletedUser := range deletedUsers {
|
||||
if err = tx.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"email", "name", "enc_algo"}),
|
||||
}).Create(&deletedUser).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Migrator().DropTable("deleted_users_old")
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debug("Successfully migrated duplicate deleted users")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isColumnPrimaryKey checks if a column is a primary key in the given model
|
||||
func isColumnPrimaryKey[T any](db *gorm.DB, columnName string) (bool, error) {
|
||||
var model T
|
||||
|
||||
cols, err := db.Migrator().ColumnTypes(&model)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, col := range cols {
|
||||
if col.Name() == columnName {
|
||||
isPrimaryKey, _ := col.PrimaryKey()
|
||||
return isPrimaryKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
143
management/server/activity/store/migration_test.go
Normal file
143
management/server/activity/store/migration_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/migration"
|
||||
"github.com/netbirdio/netbird/management/server/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
insertDeletedUserQuery = `INSERT INTO deleted_users (id, email, name, enc_algo) VALUES (?, ?, ?, ?)`
|
||||
)
|
||||
|
||||
func setupDatabase(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
|
||||
require.NoError(t, err, "Failed to create Postgres test container")
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn))
|
||||
require.NoError(t, err)
|
||||
|
||||
sql, err := db.DB()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = sql.Close()
|
||||
})
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
key, err := GenerateKey()
|
||||
require.NoError(t, err, "Failed to generate key")
|
||||
|
||||
crypt, err := NewFieldEncrypt(key)
|
||||
require.NoError(t, err, "Failed to initialize FieldEncrypt")
|
||||
|
||||
t.Run("empty table, no migration required", func(t *testing.T) {
|
||||
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
|
||||
assert.False(t, db.Migrator().HasTable("deleted_users"))
|
||||
})
|
||||
|
||||
require.NoError(t, db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`).Error)
|
||||
assert.True(t, db.Migrator().HasTable("deleted_users"))
|
||||
assert.False(t, db.Migrator().HasColumn("deleted_users", "enc_algo"))
|
||||
|
||||
require.NoError(t, migration.MigrateNewField[activity.DeletedUser](context.Background(), db, "enc_algo", ""))
|
||||
assert.True(t, db.Migrator().HasColumn("deleted_users", "enc_algo"))
|
||||
|
||||
t.Run("legacy users migration", func(t *testing.T) {
|
||||
legacyEmail := crypt.LegacyEncrypt("test.user@test.com")
|
||||
legacyName := crypt.LegacyEncrypt("Test User")
|
||||
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", legacyEmail, legacyName, "").Error)
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user2", legacyEmail, legacyName, "legacy").Error)
|
||||
|
||||
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
|
||||
|
||||
var users []activity.DeletedUser
|
||||
require.NoError(t, db.Find(&users).Error)
|
||||
assert.Len(t, users, 2)
|
||||
|
||||
for _, user := range users {
|
||||
assert.Equal(t, gcmEncAlgo, user.EncAlgo)
|
||||
|
||||
decryptedEmail, err := crypt.Decrypt(user.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test.user@test.com", decryptedEmail)
|
||||
|
||||
decryptedName, err := crypt.Decrypt(user.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Test User", decryptedName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("users already migrated, no migration", func(t *testing.T) {
|
||||
encryptedEmail, err := crypt.Encrypt("test.user@test.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
encryptedName, err := crypt.Encrypt("Test User")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user3", encryptedEmail, encryptedName, gcmEncAlgo).Error)
|
||||
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
|
||||
|
||||
var users []activity.DeletedUser
|
||||
require.NoError(t, db.Find(&users).Error)
|
||||
assert.Len(t, users, 3)
|
||||
|
||||
for _, user := range users {
|
||||
assert.Equal(t, gcmEncAlgo, user.EncAlgo)
|
||||
|
||||
decryptedEmail, err := crypt.Decrypt(user.Email)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test.user@test.com", decryptedEmail)
|
||||
|
||||
decryptedName, err := crypt.Decrypt(user.Name)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Test User", decryptedName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateDuplicateDeletedUsers(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
|
||||
require.NoError(t, migrateDuplicateDeletedUsers(context.Background(), db))
|
||||
assert.False(t, db.Migrator().HasTable("deleted_users"))
|
||||
|
||||
require.NoError(t, db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`).Error)
|
||||
assert.True(t, db.Migrator().HasTable("deleted_users"))
|
||||
|
||||
isPrimaryKey, err := isColumnPrimaryKey[activity.DeletedUser](db, "id")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isPrimaryKey)
|
||||
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", "email1", "name1", "GCM").Error)
|
||||
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", "email2", "name2", "GCM").Error)
|
||||
require.NoError(t, migrateDuplicateDeletedUsers(context.Background(), db))
|
||||
|
||||
isPrimaryKey, err = isColumnPrimaryKey[activity.DeletedUser](db, "id")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isPrimaryKey)
|
||||
|
||||
var users []activity.DeletedUser
|
||||
require.NoError(t, db.Find(&users).Error)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "user1", users[0].ID)
|
||||
assert.Equal(t, "email2", users[0].Email)
|
||||
assert.Equal(t, "name2", users[0].Name)
|
||||
assert.Equal(t, "GCM", users[0].EncAlgo)
|
||||
}
|
||||
287
management/server/activity/store/sql_store.go
Normal file
287
management/server/activity/store/sql_store.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
// eventSinkDB is the default name of the events database
|
||||
eventSinkDB = "events.db"
|
||||
|
||||
fallbackName = "unknown"
|
||||
fallbackEmail = "unknown@unknown.com"
|
||||
|
||||
gcmEncAlgo = "GCM"
|
||||
|
||||
storeEngineEnv = "NB_ACTIVITY_EVENT_STORE_ENGINE"
|
||||
postgresDsnEnv = "NB_ACTIVITY_EVENT_POSTGRES_DSN"
|
||||
sqlMaxOpenConnsEnv = "NB_SQL_MAX_OPEN_CONNS"
|
||||
)
|
||||
|
||||
type eventWithNames struct {
|
||||
activity.Event
|
||||
InitiatorName string
|
||||
InitiatorEmail string
|
||||
TargetName string
|
||||
TargetEmail string
|
||||
}
|
||||
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
type Store struct {
|
||||
db *gorm.DB
|
||||
fieldEncrypt *FieldEncrypt
|
||||
}
|
||||
|
||||
// NewSqlStore creates a new Store with an event table if not exists.
|
||||
func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
|
||||
crypt, err := NewFieldEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := initDatabase(ctx, dataDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initialize database: %w", err)
|
||||
}
|
||||
|
||||
if err = migrate(ctx, crypt, db); err != nil {
|
||||
return nil, fmt.Errorf("events database migration: %w", err)
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&activity.Event{}, &activity.DeletedUser{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("events auto migrate: %w", err)
|
||||
}
|
||||
|
||||
return &Store{
|
||||
db: db,
|
||||
fieldEncrypt: crypt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *Store) processResult(ctx context.Context, events []*eventWithNames) ([]*activity.Event, error) {
|
||||
activityEvents := make([]*activity.Event, 0)
|
||||
var cryptErr error
|
||||
|
||||
for _, event := range events {
|
||||
e := event.Event
|
||||
if e.Meta == nil {
|
||||
e.Meta = make(map[string]any)
|
||||
}
|
||||
|
||||
if event.TargetName != "" {
|
||||
name, err := store.fieldEncrypt.Decrypt(event.TargetName)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", event.TargetName)
|
||||
e.Meta["username"] = fallbackName
|
||||
} else {
|
||||
e.Meta["username"] = name
|
||||
}
|
||||
}
|
||||
|
||||
if event.TargetEmail != "" {
|
||||
email, err := store.fieldEncrypt.Decrypt(event.TargetEmail)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", event.TargetEmail)
|
||||
e.Meta["email"] = fallbackEmail
|
||||
} else {
|
||||
e.Meta["email"] = email
|
||||
}
|
||||
}
|
||||
|
||||
if event.InitiatorName != "" {
|
||||
name, err := store.fieldEncrypt.Decrypt(event.InitiatorName)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", event.InitiatorName)
|
||||
e.InitiatorName = fallbackName
|
||||
} else {
|
||||
e.InitiatorName = name
|
||||
}
|
||||
}
|
||||
|
||||
if event.InitiatorEmail != "" {
|
||||
email, err := store.fieldEncrypt.Decrypt(event.InitiatorEmail)
|
||||
if err != nil {
|
||||
cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", event.InitiatorEmail)
|
||||
e.InitiatorEmail = fallbackEmail
|
||||
} else {
|
||||
e.InitiatorEmail = email
|
||||
}
|
||||
}
|
||||
|
||||
activityEvents = append(activityEvents, &e)
|
||||
}
|
||||
|
||||
if cryptErr != nil {
|
||||
log.WithContext(ctx).Warnf("%s", cryptErr)
|
||||
}
|
||||
|
||||
return activityEvents, nil
|
||||
}
|
||||
|
||||
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
|
||||
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
|
||||
baseQuery := store.db.Model(&activity.Event{}).
|
||||
Select(`
|
||||
events.*,
|
||||
u.name AS initiator_name,
|
||||
u.email AS initiator_email,
|
||||
t.name AS target_name,
|
||||
t.email AS target_email
|
||||
`).
|
||||
Joins(`LEFT JOIN deleted_users u ON u.id = events.initiator_id`).
|
||||
Joins(`LEFT JOIN deleted_users t ON t.id = events.target_id`)
|
||||
|
||||
orderDir := "DESC"
|
||||
if !descending {
|
||||
orderDir = "ASC"
|
||||
}
|
||||
|
||||
var events []*eventWithNames
|
||||
err := baseQuery.Order("events.timestamp "+orderDir).Offset(offset).Limit(limit).
|
||||
Find(&events, "account_id = ?", accountID).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return store.processResult(ctx, events)
|
||||
}
|
||||
|
||||
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
|
||||
eventCopy := event.Copy()
|
||||
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(eventCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventCopy.Meta = meta
|
||||
|
||||
if err = store.db.Create(eventCopy).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return eventCopy, nil
|
||||
}
|
||||
|
||||
// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
|
||||
// this item from meta map
|
||||
func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
|
||||
email, ok := event.Meta["email"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
name, ok := event.Meta["name"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
deletedUser := activity.DeletedUser{
|
||||
ID: event.TargetID,
|
||||
EncAlgo: gcmEncAlgo,
|
||||
}
|
||||
|
||||
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deletedUser.Email = encryptedEmail
|
||||
|
||||
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deletedUser.Name = encryptedName
|
||||
|
||||
err = store.db.Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "id"}},
|
||||
DoUpdates: clause.AssignmentColumns([]string{"email", "name"}),
|
||||
}).Create(deletedUser).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(event.Meta) == 2 {
|
||||
return nil, nil // nolint
|
||||
}
|
||||
delete(event.Meta, "email")
|
||||
delete(event.Meta, "name")
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
// Close the Store
|
||||
func (store *Store) Close(_ context.Context) error {
|
||||
if store.db != nil {
|
||||
sql, err := store.db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
|
||||
var dialector gorm.Dialector
|
||||
var storeEngine = types.SqliteStoreEngine
|
||||
|
||||
if engine, ok := os.LookupEnv(storeEngineEnv); ok {
|
||||
storeEngine = types.Engine(engine)
|
||||
}
|
||||
|
||||
switch storeEngine {
|
||||
case types.SqliteStoreEngine:
|
||||
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
|
||||
case types.PostgresStoreEngine:
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s environment variable not set", postgresDsnEnv)
|
||||
}
|
||||
dialector = postgres.Open(dsn)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported store engine: %s", storeEngine)
|
||||
}
|
||||
log.WithContext(ctx).Infof("using %s as activity event store engine", storeEngine)
|
||||
|
||||
db, err := gorm.Open(dialector, &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open db connection: %w", err)
|
||||
}
|
||||
|
||||
return configureConnectionPool(db, storeEngine)
|
||||
}
|
||||
|
||||
func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, error) {
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
} else {
|
||||
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
|
||||
if err != nil {
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(conns)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
57
management/server/activity/store/sql_store_test.go
Normal file
57
management/server/activity/store/sql_store_test.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
func TestNewSqlStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
key, _ := GenerateKey()
|
||||
store, err := NewSqlStore(context.Background(), dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
defer store.Close(context.Background()) //nolint
|
||||
|
||||
accountID := "account_1"
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = store.Save(context.Background(), &activity.Event{
|
||||
Timestamp: time.Now().UTC(),
|
||||
Activity: activity.PeerAddedByUser,
|
||||
InitiatorID: "user_" + fmt.Sprint(i),
|
||||
TargetID: "peer_" + fmt.Sprint(i),
|
||||
AccountID: accountID,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result, err := store.Get(context.Background(), accountID, 0, 10, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, result, 10)
|
||||
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
|
||||
|
||||
result, err = store.Get(context.Background(), accountID, 0, 5, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Len(t, result, 5)
|
||||
assert.True(t, result[0].Timestamp.After(result[len(result)-1].Timestamp))
|
||||
}
|
||||
Reference in New Issue
Block a user