mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management, infrastructure, idp] Simplified IdP Management - Embedded IdP (#5008)
Embed Dex as a built-in IdP to simplify self-hosting setup. Adds an embedded OIDC Identity Provider (Dex) with local user management and optional external IdP connectors (Google/GitHub/OIDC/SAML), plus device-auth flow for CLI login. Introduces instance onboarding/setup endpoints (including owner creation), field-level encryption for sensitive user data, a streamlined self-hosting provisioning script, and expanded APIs + test coverage for IdP management. more at https://github.com/netbirdio/netbird/pull/5008#issuecomment-3718987393
This commit is contained in:
122
management/server/types/identity_provider.go
Normal file
122
management/server/types/identity_provider.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Identity provider validation errors
|
||||
var (
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
)
|
||||
|
||||
// IdentityProviderType is the type of identity provider
|
||||
type IdentityProviderType string
|
||||
|
||||
const (
|
||||
// IdentityProviderTypeOIDC is a generic OIDC identity provider
|
||||
IdentityProviderTypeOIDC IdentityProviderType = "oidc"
|
||||
// IdentityProviderTypeZitadel is the Zitadel identity provider
|
||||
IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
|
||||
// IdentityProviderTypeEntra is the Microsoft Entra (Azure AD) identity provider
|
||||
IdentityProviderTypeEntra IdentityProviderType = "entra"
|
||||
// IdentityProviderTypeGoogle is the Google identity provider
|
||||
IdentityProviderTypeGoogle IdentityProviderType = "google"
|
||||
// IdentityProviderTypeOkta is the Okta identity provider
|
||||
IdentityProviderTypeOkta IdentityProviderType = "okta"
|
||||
// IdentityProviderTypePocketID is the PocketID identity provider
|
||||
IdentityProviderTypePocketID IdentityProviderType = "pocketid"
|
||||
// IdentityProviderTypeMicrosoft is the Microsoft identity provider
|
||||
IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft"
|
||||
// IdentityProviderTypeAuthentik is the Authentik identity provider
|
||||
IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
|
||||
// IdentityProviderTypeKeycloak is the Keycloak identity provider
|
||||
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an identity provider configuration
|
||||
type IdentityProvider struct {
|
||||
// ID is the unique identifier of the identity provider
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// Type is the type of identity provider
|
||||
Type IdentityProviderType
|
||||
// Name is a human-readable name for the identity provider
|
||||
Name string
|
||||
// Issuer is the OIDC issuer URL
|
||||
Issuer string
|
||||
// ClientID is the OAuth2 client ID
|
||||
ClientID string
|
||||
// ClientSecret is the OAuth2 client secret
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
// Copy returns a copy of the IdentityProvider
|
||||
func (idp *IdentityProvider) Copy() *IdentityProvider {
|
||||
return &IdentityProvider{
|
||||
ID: idp.ID,
|
||||
AccountID: idp.AccountID,
|
||||
Type: idp.Type,
|
||||
Name: idp.Name,
|
||||
Issuer: idp.Issuer,
|
||||
ClientID: idp.ClientID,
|
||||
ClientSecret: idp.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// EventMeta returns a map of metadata for activity events
|
||||
func (idp *IdentityProvider) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": idp.Name,
|
||||
"type": string(idp.Type),
|
||||
"issuer": idp.Issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the identity provider configuration
|
||||
func (idp *IdentityProvider) Validate() error {
|
||||
if idp.Name == "" {
|
||||
return ErrIdentityProviderNameRequired
|
||||
}
|
||||
if idp.Type == "" {
|
||||
return ErrIdentityProviderTypeRequired
|
||||
}
|
||||
if !idp.Type.IsValid() {
|
||||
return ErrIdentityProviderTypeUnsupported
|
||||
}
|
||||
if !idp.Type.HasBuiltInIssuer() && idp.Issuer == "" {
|
||||
return ErrIdentityProviderIssuerRequired
|
||||
}
|
||||
if idp.Issuer != "" {
|
||||
parsedURL, err := url.Parse(idp.Issuer)
|
||||
if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return ErrIdentityProviderIssuerInvalid
|
||||
}
|
||||
}
|
||||
if idp.ClientID == "" {
|
||||
return ErrIdentityProviderClientIDRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValid checks if the given type is a supported identity provider type
|
||||
func (t IdentityProviderType) IsValid() bool {
|
||||
switch t {
|
||||
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
|
||||
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
|
||||
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasBuiltInIssuer returns true for types that don't require an issuer URL
|
||||
func (t IdentityProviderType) HasBuiltInIssuer() bool {
|
||||
return t == IdentityProviderTypeGoogle || t == IdentityProviderTypeMicrosoft
|
||||
}
|
||||
137
management/server/types/identity_provider_test.go
Normal file
137
management/server/types/identity_provider_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIdentityProvider_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idp *IdentityProvider
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "valid OIDC provider",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid OIDC provider with path",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com/oauth2/issuer",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
idp: &IdentityProvider{
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderNameRequired,
|
||||
},
|
||||
{
|
||||
name: "missing type",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderTypeRequired,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: "invalid",
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderTypeUnsupported,
|
||||
},
|
||||
{
|
||||
name: "missing issuer for OIDC",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerRequired,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - no scheme",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - no host",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - just path",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "/oauth2/issuer",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderClientIDRequired,
|
||||
},
|
||||
{
|
||||
name: "Google provider without issuer is valid",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Google SSO",
|
||||
Type: IdentityProviderTypeGoogle,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Microsoft provider without issuer is valid",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Microsoft SSO",
|
||||
Type: IdentityProviderTypeMicrosoft,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.idp.Validate()
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -65,7 +66,11 @@ type UserInfo struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
PendingApproval bool `json:"pending_approval"`
|
||||
Password string `json:"password"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
// IdPID is the identity provider ID (connector ID) extracted from the Dex-encoded user ID.
|
||||
// This field is only populated when the user ID can be decoded from Dex's format.
|
||||
IdPID string `json:"idp_id,omitempty"`
|
||||
}
|
||||
|
||||
// User represents a user of the system
|
||||
@@ -96,6 +101,9 @@ type User struct {
|
||||
Issued string `gorm:"default:api"`
|
||||
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
|
||||
Name string `gorm:"default:''"`
|
||||
Email string `gorm:"default:''"`
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
@@ -143,10 +151,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
|
||||
name := u.Name
|
||||
if u.IsServiceUser {
|
||||
name = u.ServiceUserName
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
@@ -178,6 +192,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
PendingApproval: u.PendingApproval,
|
||||
Password: userData.Password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -204,11 +219,13 @@ func (u *User) Copy() *User {
|
||||
CreatedAt: u.CreatedAt,
|
||||
Issued: u.Issued,
|
||||
IntegrationReference: u.IntegrationReference,
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, email string, name string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
@@ -218,20 +235,70 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
|
||||
AutoGroups: autoGroups,
|
||||
Issued: issued,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Name: name,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewRegularUser(id, email, name string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, email, name)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, "", "")
|
||||
}
|
||||
|
||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||
func NewOwnerUser(id string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewOwnerUser(id string, email string, name string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, email, name)
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the user's sensitive fields (Email and Name) in place.
|
||||
func (u *User) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if u.Email != "" {
|
||||
u.Email, err = enc.Encrypt(u.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if u.Name != "" {
|
||||
u.Name, err = enc.Encrypt(u.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the user's sensitive fields (Email and Name) in place.
|
||||
func (u *User) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if u.Email != "" {
|
||||
u.Email, err = enc.Decrypt(u.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if u.Name != "" {
|
||||
u.Name, err = enc.Decrypt(u.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
298
management/server/types/user_test.go
Normal file
298
management/server/types/user_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUser_EncryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("encrypt email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
|
||||
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
|
||||
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
|
||||
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt empty email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-2",
|
||||
Email: "",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt only email", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-3",
|
||||
Email: "test@example.com",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
|
||||
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt only name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-4",
|
||||
Email: "",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
|
||||
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
|
||||
})
|
||||
|
||||
t.Run("nil encryptor returns no error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-5",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
|
||||
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_DecryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("decrypt email and name", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-1",
|
||||
Email: originalEmail,
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
|
||||
})
|
||||
|
||||
t.Run("decrypt empty email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-2",
|
||||
Email: "",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("decrypt only email", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
|
||||
user := &User{
|
||||
Id: "user-3",
|
||||
Email: originalEmail,
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("decrypt only name", func(t *testing.T) {
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-4",
|
||||
Email: "",
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
|
||||
})
|
||||
|
||||
t.Run("nil encryptor returns no error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-5",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
|
||||
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
|
||||
})
|
||||
|
||||
t.Run("decrypt with invalid ciphertext returns error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-6",
|
||||
Email: "not-valid-base64-ciphertext!!!",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decrypt email")
|
||||
})
|
||||
|
||||
t.Run("decrypt with wrong key returns error", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-7",
|
||||
Email: originalEmail,
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
differentKey, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
differentEncrypt, err := crypt.NewFieldEncrypt(differentKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(differentEncrypt)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decrypt email")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_EncryptDecryptRoundTrip(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
email string
|
||||
uname string
|
||||
}{
|
||||
{
|
||||
name: "standard email and name",
|
||||
email: "user@example.com",
|
||||
uname: "John Doe",
|
||||
},
|
||||
{
|
||||
name: "email with special characters",
|
||||
email: "user+tag@sub.example.com",
|
||||
uname: "O'Brien, Mary-Jane",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
email: "user@example.com",
|
||||
uname: "Jean-Pierre Müller 日本語",
|
||||
},
|
||||
{
|
||||
name: "long values",
|
||||
email: "very.long.email.address.that.is.quite.extended@subdomain.example.organization.com",
|
||||
uname: "A Very Long Name That Contains Many Words And Is Quite Extended For Testing Purposes",
|
||||
},
|
||||
{
|
||||
name: "empty email only",
|
||||
email: "",
|
||||
uname: "Name Only",
|
||||
},
|
||||
{
|
||||
name: "empty name only",
|
||||
email: "email@only.com",
|
||||
uname: "",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
email: "",
|
||||
uname: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "test-user",
|
||||
Email: tc.email,
|
||||
Name: tc.uname,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.email != "" {
|
||||
assert.NotEqual(t, tc.email, user.Email, "email should be encrypted")
|
||||
}
|
||||
if tc.uname != "" {
|
||||
assert.NotEqual(t, tc.uname, user.Name, "name should be encrypted")
|
||||
}
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.email, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, tc.uname, user.Name, "decrypted name should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user