[management] Add postgres support for activity event store (#3890)

This commit is contained in:
Bethuel Mmbaga
2025-06-04 17:38:49 +03:00
committed by GitHub
parent ea4d13e96d
commit b604c66140
8 changed files with 92 additions and 36 deletions

View File

@@ -0,0 +1,136 @@
package store
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
gcm cipher.AEAD
}
func GenerateKey() (string, error) {
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
return "", err
}
readableKey := base64.StdEncoding.EncodeToString(key)
return readableKey, nil
}
func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
binKey, err := base64.StdEncoding.DecodeString(key)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(binKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
ec := &FieldEncrypt{
block: block,
gcm: gcm,
}
return ec, nil
}
func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, plainText)
return base64.StdEncoding.EncodeToString(cipherText)
}
// Encrypt encrypts plaintext using AES-GCM
func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
plaintext := []byte(payload)
nonceSize := ec.gcm.NonceSize()
nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
if _, err := rand.Read(nonce); err != nil {
return "", err
}
ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
cbc := cipher.NewCBCDecrypter(ec.block, iv)
cbc.CryptBlocks(cipherText, cipherText)
payload, err := pkcs5UnPadding(cipherText)
if err != nil {
return "", err
}
return string(payload), nil
}
// Decrypt decrypts ciphertext using AES-GCM
func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
}
nonceSize := ec.gcm.NonceSize()
if len(cipherText) < nonceSize {
return "", errors.New("cipher text too short")
}
nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
if err != nil {
return "", err
}
return string(plainText), nil
}
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
if srcLen == 0 {
return nil, errors.New("input data is empty")
}
paddingLen := int(src[srcLen-1])
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
return nil, errors.New("invalid padding size")
}
// Verify that all padding bytes are the same
for i := 0; i < paddingLen; i++ {
if src[srcLen-1-i] != byte(paddingLen) {
return nil, errors.New("invalid padding")
}
}
return src[:srcLen-paddingLen], nil
}

View File

@@ -0,0 +1,310 @@
package store
import (
"bytes"
"testing"
)
func TestGenerateKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.Decrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestGenerateKeyLegacy(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted := ee.LegacyEncrypt(testData)
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
decrypted, err := ee.LegacyDecrypt(encrypted)
if err != nil {
t.Fatalf("failed to decrypt data: %s", err)
}
if decrypted != testData {
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
}
}
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
encrypted, err := ee.Encrypt(testData)
if err != nil {
t.Fatalf("failed to encrypt data: %s", err)
}
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
newKey, err := GenerateKey()
if err != nil {
t.Fatalf("failed to generate key: %s", err)
}
ee, err = NewFieldEncrypt(newKey)
if err != nil {
t.Fatalf("failed to init email encryption: %s", err)
}
res, _ := ee.Decrypt(encrypted)
if res == testData {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
func TestEncryptDecrypt(t *testing.T) {
// Generate a key for encryption/decryption
key, err := GenerateKey()
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
// Initialize the FieldEncrypt with the generated key
ec, err := NewFieldEncrypt(key)
if err != nil {
t.Fatalf("Failed to create FieldEncrypt: %v", err)
}
// Test cases
testCases := []struct {
name string
input string
}{
{
name: "Empty String",
input: "",
},
{
name: "Short String",
input: "Hello",
},
{
name: "String with Spaces",
input: "Hello, World!",
},
{
name: "Long String",
input: "The quick brown fox jumps over the lazy dog.",
},
{
name: "Unicode Characters",
input: "こんにちは世界",
},
{
name: "Special Characters",
input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
},
{
name: "Numeric String",
input: "1234567890",
},
{
name: "Repeated Characters",
input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
},
{
name: "Multi-block String",
input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
},
{
name: "Non-ASCII and ASCII Mix",
input: "Hello 世界 123",
},
}
for _, tc := range testCases {
t.Run(tc.name+" - Legacy", func(t *testing.T) {
// Legacy Encryption
encryptedLegacy := ec.LegacyEncrypt(tc.input)
if encryptedLegacy == "" {
t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
}
// Legacy Decryption
decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
if err != nil {
t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedLegacy != tc.input {
t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
}
})
t.Run(tc.name+" - New", func(t *testing.T) {
// New Encryption
encryptedNew, err := ec.Encrypt(tc.input)
if err != nil {
t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
}
if encryptedNew == "" {
t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
}
// New Decryption
decryptedNew, err := ec.Decrypt(encryptedNew)
if err != nil {
t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
}
// Verify that the decrypted value matches the original input
if decryptedNew != tc.input {
t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
}
})
}
}
func TestPKCS5UnPadding(t *testing.T) {
tests := []struct {
name string
input []byte
expected []byte
expectError bool
}{
{
name: "Valid Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
expected: []byte("Hello, World!"),
},
{
name: "Empty Input",
input: []byte{},
expectError: true,
},
{
name: "Padding Length Zero",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
expectError: true,
},
{
name: "Padding Length Exceeds Block Size",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
expectError: true,
},
{
name: "Padding Length Exceeds Input Length",
input: []byte{5, 5, 5},
expectError: true,
},
{
name: "Invalid Padding Bytes",
input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
expectError: true,
},
{
name: "Valid Single Byte Padding",
input: append([]byte("Hello, World!"), byte(1)),
expected: []byte("Hello, World!"),
},
{
name: "Invalid Mixed Padding Bytes",
input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
expectError: true,
},
{
name: "Valid Full Block Padding",
input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
expected: []byte("Hello, World!"),
},
{
name: "Non-Padding Byte at End",
input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
expectError: true,
},
{
name: "Valid Padding with Different Text Length",
input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
expected: []byte("Test"),
},
{
name: "Padding Length Equal to Input Length",
input: bytes.Repeat([]byte{8}, 8),
expected: []byte{},
},
{
name: "Invalid Padding Length Zero (Again)",
input: append([]byte("Test"), byte(0)),
expectError: true,
},
{
name: "Padding Length Greater Than Input",
input: []byte{10},
expectError: true,
},
{
name: "Input Length Not Multiple of Block Size",
input: append([]byte("Invalid Length"), byte(1)),
expected: []byte("Invalid Length"),
},
{
name: "Valid Padding with Non-ASCII Characters",
input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
expected: []byte("こんにちは"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := pkcs5UnPadding(tt.input)
if tt.expectError {
if err == nil {
t.Errorf("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Did not expect error but got: %v", err)
}
if !bytes.Equal(result, tt.expected) {
t.Errorf("Expected output %v, got %v", tt.expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,185 @@
package store
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/migration"
)
func migrate(ctx context.Context, crypt *FieldEncrypt, db *gorm.DB) error {
migrations := getMigrations(ctx, crypt)
for _, m := range migrations {
if err := m(db); err != nil {
return err
}
}
return nil
}
type migrationFunc func(*gorm.DB) error
func getMigrations(ctx context.Context, crypt *FieldEncrypt) []migrationFunc {
return []migrationFunc{
func(db *gorm.DB) error {
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "name", "")
},
func(db *gorm.DB) error {
return migration.MigrateNewField[activity.DeletedUser](ctx, db, "enc_algo", "")
},
func(db *gorm.DB) error {
return migrateLegacyEncryptedUsersToGCM(ctx, db, crypt)
},
func(db *gorm.DB) error {
return migrateDuplicateDeletedUsers(ctx, db)
},
}
}
// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using
// legacy CBC encryption with a static IV to the new GCM encryption method.
func migrateLegacyEncryptedUsersToGCM(ctx context.Context, db *gorm.DB, crypt *FieldEncrypt) error {
model := &activity.DeletedUser{}
if !db.Migrator().HasTable(model) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no CBC to GCM migration needed", model)
return nil
}
var deletedUsers []activity.DeletedUser
err := db.Model(model).Find(&deletedUsers, "enc_algo IS NULL OR enc_algo != ?", gcmEncAlgo).Error
if err != nil {
return fmt.Errorf("failed to query deleted_users: %w", err)
}
if len(deletedUsers) == 0 {
log.WithContext(ctx).Debug("No CBC encrypted deleted users to migrate")
return nil
}
if err = db.Transaction(func(tx *gorm.DB) error {
for _, user := range deletedUsers {
if err = updateDeletedUserData(tx, user, crypt); err != nil {
return fmt.Errorf("failed to migrate deleted user %s: %w", user.ID, err)
}
}
return nil
}); err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
return nil
}
func updateDeletedUserData(transaction *gorm.DB, user activity.DeletedUser, crypt *FieldEncrypt) error {
var err error
var decryptedEmail, decryptedName string
if user.Email != "" {
decryptedEmail, err = crypt.LegacyDecrypt(user.Email)
if err != nil {
return fmt.Errorf("failed to decrypt email: %w", err)
}
}
if user.Name != "" {
decryptedName, err = crypt.LegacyDecrypt(user.Name)
if err != nil {
return fmt.Errorf("failed to decrypt name: %w", err)
}
}
updatedUser := user
updatedUser.EncAlgo = gcmEncAlgo
updatedUser.Email, err = crypt.Encrypt(decryptedEmail)
if err != nil {
return fmt.Errorf("failed to encrypt email: %w", err)
}
updatedUser.Name, err = crypt.Encrypt(decryptedName)
if err != nil {
return fmt.Errorf("failed to encrypt name: %w", err)
}
return transaction.Model(&updatedUser).Omit("id").Updates(updatedUser).Error
}
// MigrateDuplicateDeletedUsers removes duplicates and ensures the id column is marked as the primary key
func migrateDuplicateDeletedUsers(ctx context.Context, db *gorm.DB) error {
model := &activity.DeletedUser{}
if !db.Migrator().HasTable(model) {
log.WithContext(ctx).Debugf("Table for %T does not exist, no duplicate migration needed", model)
return nil
}
isPrimaryKey, err := isColumnPrimaryKey[activity.DeletedUser](db, "id")
if err != nil {
return err
}
if isPrimaryKey {
log.WithContext(ctx).Debug("No duplicate deleted users to migrate")
return nil
}
if err = db.Transaction(func(tx *gorm.DB) error {
if err = tx.Migrator().RenameTable("deleted_users", "deleted_users_old"); err != nil {
return err
}
if err = tx.Migrator().CreateTable(model); err != nil {
return err
}
var deletedUsers []activity.DeletedUser
if err = tx.Table("deleted_users_old").Find(&deletedUsers).Error; err != nil {
return err
}
for _, deletedUser := range deletedUsers {
if err = tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
DoUpdates: clause.AssignmentColumns([]string{"email", "name", "enc_algo"}),
}).Create(&deletedUser).Error; err != nil {
return err
}
}
return tx.Migrator().DropTable("deleted_users_old")
}); err != nil {
return err
}
log.WithContext(ctx).Debug("Successfully migrated duplicate deleted users")
return nil
}
// isColumnPrimaryKey checks if a column is a primary key in the given model
func isColumnPrimaryKey[T any](db *gorm.DB, columnName string) (bool, error) {
var model T
cols, err := db.Migrator().ColumnTypes(&model)
if err != nil {
return false, err
}
for _, col := range cols {
if col.Name() == columnName {
isPrimaryKey, _ := col.PrimaryKey()
return isPrimaryKey, nil
}
}
return false, nil
}

View File

@@ -0,0 +1,143 @@
package store
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/migration"
"github.com/netbirdio/netbird/management/server/testutil"
)
const (
insertDeletedUserQuery = `INSERT INTO deleted_users (id, email, name, enc_algo) VALUES (?, ?, ?, ?)`
)
func setupDatabase(t *testing.T) *gorm.DB {
t.Helper()
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
require.NoError(t, err, "Failed to create Postgres test container")
t.Cleanup(cleanup)
db, err := gorm.Open(postgres.Open(dsn))
require.NoError(t, err)
sql, err := db.DB()
require.NoError(t, err)
t.Cleanup(func() {
_ = sql.Close()
})
return db
}
func TestMigrateLegacyEncryptedUsersToGCM(t *testing.T) {
db := setupDatabase(t)
key, err := GenerateKey()
require.NoError(t, err, "Failed to generate key")
crypt, err := NewFieldEncrypt(key)
require.NoError(t, err, "Failed to initialize FieldEncrypt")
t.Run("empty table, no migration required", func(t *testing.T) {
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
assert.False(t, db.Migrator().HasTable("deleted_users"))
})
require.NoError(t, db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`).Error)
assert.True(t, db.Migrator().HasTable("deleted_users"))
assert.False(t, db.Migrator().HasColumn("deleted_users", "enc_algo"))
require.NoError(t, migration.MigrateNewField[activity.DeletedUser](context.Background(), db, "enc_algo", ""))
assert.True(t, db.Migrator().HasColumn("deleted_users", "enc_algo"))
t.Run("legacy users migration", func(t *testing.T) {
legacyEmail := crypt.LegacyEncrypt("test.user@test.com")
legacyName := crypt.LegacyEncrypt("Test User")
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", legacyEmail, legacyName, "").Error)
require.NoError(t, db.Exec(insertDeletedUserQuery, "user2", legacyEmail, legacyName, "legacy").Error)
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
var users []activity.DeletedUser
require.NoError(t, db.Find(&users).Error)
assert.Len(t, users, 2)
for _, user := range users {
assert.Equal(t, gcmEncAlgo, user.EncAlgo)
decryptedEmail, err := crypt.Decrypt(user.Email)
require.NoError(t, err)
assert.Equal(t, "test.user@test.com", decryptedEmail)
decryptedName, err := crypt.Decrypt(user.Name)
require.NoError(t, err)
require.Equal(t, "Test User", decryptedName)
}
})
t.Run("users already migrated, no migration", func(t *testing.T) {
encryptedEmail, err := crypt.Encrypt("test.user@test.com")
require.NoError(t, err)
encryptedName, err := crypt.Encrypt("Test User")
require.NoError(t, err)
require.NoError(t, db.Exec(insertDeletedUserQuery, "user3", encryptedEmail, encryptedName, gcmEncAlgo).Error)
require.NoError(t, migrateLegacyEncryptedUsersToGCM(context.Background(), db, crypt))
var users []activity.DeletedUser
require.NoError(t, db.Find(&users).Error)
assert.Len(t, users, 3)
for _, user := range users {
assert.Equal(t, gcmEncAlgo, user.EncAlgo)
decryptedEmail, err := crypt.Decrypt(user.Email)
require.NoError(t, err)
assert.Equal(t, "test.user@test.com", decryptedEmail)
decryptedName, err := crypt.Decrypt(user.Name)
require.NoError(t, err)
require.Equal(t, "Test User", decryptedName)
}
})
}
func TestMigrateDuplicateDeletedUsers(t *testing.T) {
db := setupDatabase(t)
require.NoError(t, migrateDuplicateDeletedUsers(context.Background(), db))
assert.False(t, db.Migrator().HasTable("deleted_users"))
require.NoError(t, db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`).Error)
assert.True(t, db.Migrator().HasTable("deleted_users"))
isPrimaryKey, err := isColumnPrimaryKey[activity.DeletedUser](db, "id")
require.NoError(t, err)
assert.False(t, isPrimaryKey)
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", "email1", "name1", "GCM").Error)
require.NoError(t, db.Exec(insertDeletedUserQuery, "user1", "email2", "name2", "GCM").Error)
require.NoError(t, migrateDuplicateDeletedUsers(context.Background(), db))
isPrimaryKey, err = isColumnPrimaryKey[activity.DeletedUser](db, "id")
require.NoError(t, err)
assert.True(t, isPrimaryKey)
var users []activity.DeletedUser
require.NoError(t, db.Find(&users).Error)
assert.Len(t, users, 1)
assert.Equal(t, "user1", users[0].ID)
assert.Equal(t, "email2", users[0].Email)
assert.Equal(t, "name2", users[0].Name)
assert.Equal(t, "GCM", users[0].EncAlgo)
}

View File

@@ -0,0 +1,287 @@
package store
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/types"
)
const (
// eventSinkDB is the default name of the events database
eventSinkDB = "events.db"
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
gcmEncAlgo = "GCM"
storeEngineEnv = "NB_ACTIVITY_EVENT_STORE_ENGINE"
postgresDsnEnv = "NB_ACTIVITY_EVENT_POSTGRES_DSN"
sqlMaxOpenConnsEnv = "NB_SQL_MAX_OPEN_CONNS"
)
type eventWithNames struct {
activity.Event
InitiatorName string
InitiatorEmail string
TargetName string
TargetEmail string
}
// Store is the implementation of the activity.Store interface backed by SQLite
type Store struct {
db *gorm.DB
fieldEncrypt *FieldEncrypt
}
// NewSqlStore creates a new Store with an event table if not exists.
func NewSqlStore(ctx context.Context, dataDir string, encryptionKey string) (*Store, error) {
crypt, err := NewFieldEncrypt(encryptionKey)
if err != nil {
return nil, err
}
db, err := initDatabase(ctx, dataDir)
if err != nil {
return nil, fmt.Errorf("initialize database: %w", err)
}
if err = migrate(ctx, crypt, db); err != nil {
return nil, fmt.Errorf("events database migration: %w", err)
}
err = db.AutoMigrate(&activity.Event{}, &activity.DeletedUser{})
if err != nil {
return nil, fmt.Errorf("events auto migrate: %w", err)
}
return &Store{
db: db,
fieldEncrypt: crypt,
}, nil
}
func (store *Store) processResult(ctx context.Context, events []*eventWithNames) ([]*activity.Event, error) {
activityEvents := make([]*activity.Event, 0)
var cryptErr error
for _, event := range events {
e := event.Event
if e.Meta == nil {
e.Meta = make(map[string]any)
}
if event.TargetName != "" {
name, err := store.fieldEncrypt.Decrypt(event.TargetName)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt username for target id: %s", event.TargetName)
e.Meta["username"] = fallbackName
} else {
e.Meta["username"] = name
}
}
if event.TargetEmail != "" {
email, err := store.fieldEncrypt.Decrypt(event.TargetEmail)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt email address for target id: %s", event.TargetEmail)
e.Meta["email"] = fallbackEmail
} else {
e.Meta["email"] = email
}
}
if event.InitiatorName != "" {
name, err := store.fieldEncrypt.Decrypt(event.InitiatorName)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt username of initiator: %s", event.InitiatorName)
e.InitiatorName = fallbackName
} else {
e.InitiatorName = name
}
}
if event.InitiatorEmail != "" {
email, err := store.fieldEncrypt.Decrypt(event.InitiatorEmail)
if err != nil {
cryptErr = fmt.Errorf("failed to decrypt email address of initiator: %s", event.InitiatorEmail)
e.InitiatorEmail = fallbackEmail
} else {
e.InitiatorEmail = email
}
}
activityEvents = append(activityEvents, &e)
}
if cryptErr != nil {
log.WithContext(ctx).Warnf("%s", cryptErr)
}
return activityEvents, nil
}
// Get returns "limit" number of events from index ordered descending or ascending by a timestamp
func (store *Store) Get(ctx context.Context, accountID string, offset, limit int, descending bool) ([]*activity.Event, error) {
baseQuery := store.db.Model(&activity.Event{}).
Select(`
events.*,
u.name AS initiator_name,
u.email AS initiator_email,
t.name AS target_name,
t.email AS target_email
`).
Joins(`LEFT JOIN deleted_users u ON u.id = events.initiator_id`).
Joins(`LEFT JOIN deleted_users t ON t.id = events.target_id`)
orderDir := "DESC"
if !descending {
orderDir = "ASC"
}
var events []*eventWithNames
err := baseQuery.Order("events.timestamp "+orderDir).Offset(offset).Limit(limit).
Find(&events, "account_id = ?", accountID).Error
if err != nil {
return nil, err
}
return store.processResult(ctx, events)
}
// Save an event in the SQLite events table end encrypt the "email" element in meta map
func (store *Store) Save(_ context.Context, event *activity.Event) (*activity.Event, error) {
eventCopy := event.Copy()
meta, err := store.saveDeletedUserEmailAndNameInEncrypted(eventCopy)
if err != nil {
return nil, err
}
eventCopy.Meta = meta
if err = store.db.Create(eventCopy).Error; err != nil {
return nil, err
}
return eventCopy, nil
}
// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete
// this item from meta map
func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) {
email, ok := event.Meta["email"]
if !ok {
return event.Meta, nil
}
name, ok := event.Meta["name"]
if !ok {
return event.Meta, nil
}
deletedUser := activity.DeletedUser{
ID: event.TargetID,
EncAlgo: gcmEncAlgo,
}
encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
if err != nil {
return nil, err
}
deletedUser.Email = encryptedEmail
encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
if err != nil {
return nil, err
}
deletedUser.Name = encryptedName
err = store.db.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
DoUpdates: clause.AssignmentColumns([]string{"email", "name"}),
}).Create(deletedUser).Error
if err != nil {
return nil, err
}
if len(event.Meta) == 2 {
return nil, nil // nolint
}
delete(event.Meta, "email")
delete(event.Meta, "name")
return event.Meta, nil
}
// Close the Store
func (store *Store) Close(_ context.Context) error {
if store.db != nil {
sql, err := store.db.DB()
if err != nil {
return err
}
return sql.Close()
}
return nil
}
func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
var dialector gorm.Dialector
var storeEngine = types.SqliteStoreEngine
if engine, ok := os.LookupEnv(storeEngineEnv); ok {
storeEngine = types.Engine(engine)
}
switch storeEngine {
case types.SqliteStoreEngine:
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
case types.PostgresStoreEngine:
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {
return nil, fmt.Errorf("%s environment variable not set", postgresDsnEnv)
}
dialector = postgres.Open(dsn)
default:
return nil, fmt.Errorf("unsupported store engine: %s", storeEngine)
}
log.WithContext(ctx).Infof("using %s as activity event store engine", storeEngine)
db, err := gorm.Open(dialector, &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)})
if err != nil {
return nil, fmt.Errorf("open db connection: %w", err)
}
return configureConnectionPool(db, storeEngine)
}
func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, error) {
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
if storeEngine == types.SqliteStoreEngine {
sqlDB.SetMaxOpenConns(1)
} else {
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
if err != nil {
conns = runtime.NumCPU()
}
sqlDB.SetMaxOpenConns(conns)
}
return db, nil
}

View File

@@ -0,0 +1,57 @@
package store
import (
"context"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/activity"
)
func TestNewSqlStore(t *testing.T) {
dataDir := t.TempDir()
key, _ := GenerateKey()
store, err := NewSqlStore(context.Background(), dataDir, key)
if err != nil {
t.Fatal(err)
return
}
defer store.Close(context.Background()) //nolint
accountID := "account_1"
for i := 0; i < 10; i++ {
_, err = store.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
TargetID: "peer_" + fmt.Sprint(i),
AccountID: accountID,
})
if err != nil {
t.Fatal(err)
return
}
}
result, err := store.Get(context.Background(), accountID, 0, 10, false)
if err != nil {
t.Fatal(err)
return
}
assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
result, err = store.Get(context.Background(), accountID, 0, 5, true)
if err != nil {
t.Fatal(err)
return
}
assert.Len(t, result, 5)
assert.True(t, result[0].Timestamp.After(result[len(result)-1].Timestamp))
}