diff --git a/management/cmd/management.go b/management/cmd/management.go index ca333b931..545d6840c 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -149,7 +149,7 @@ var ( } if key != "" { - log.Debugf("update config with activity store key") + log.Infof("update config with activity store key") config.DataStoreEncryptionKey = key err := updateMgmtConfig(mgmtConfig, config) if err != nil { @@ -466,7 +466,7 @@ func loadMgmtConfig(mgmtConfigPath string) (*server.Config, error) { } func updateMgmtConfig(path string, config *server.Config) error { - return util.WriteJson(path, config) + return util.DirectWriteJson(path, config) } // OIDCConfigResponse used for parsing OIDC config response diff --git a/management/server/activity/event.go b/management/server/activity/event.go index 1bf86ef2c..f212f5b21 100644 --- a/management/server/activity/event.go +++ b/management/server/activity/event.go @@ -18,7 +18,9 @@ type Event struct { ID uint64 // InitiatorID is the ID of an object that initiated the event (e.g., a user) InitiatorID string - // InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only + // InitiatorName is the name of an object that initiated the event. + InitiatorName string + // InitiatorEmail is the email address of an object that initiated the event. InitiatorEmail string // TargetID is the ID of an object that was effected by the event (e.g., a peer) TargetID string @@ -42,6 +44,7 @@ func (e *Event) Copy() *Event { Activity: e.Activity, ID: e.ID, InitiatorID: e.InitiatorID, + InitiatorName: e.InitiatorName, InitiatorEmail: e.InitiatorEmail, TargetID: e.TargetID, AccountID: e.AccountID, diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go index 8f2755604..cf4dda746 100644 --- a/management/server/activity/sqlite/crypt.go +++ b/management/server/activity/sqlite/crypt.go @@ -11,7 +11,7 @@ import ( var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} -type EmailEncrypt struct { +type FieldEncrypt struct { block cipher.Block } @@ -25,7 +25,7 @@ func GenerateKey() (string, error) { return readableKey, nil } -func NewEmailEncrypt(key string) (*EmailEncrypt, error) { +func NewFieldEncrypt(key string) (*FieldEncrypt, error) { binKey, err := base64.StdEncoding.DecodeString(key) if err != nil { return nil, err @@ -35,14 +35,14 @@ func NewEmailEncrypt(key string) (*EmailEncrypt, error) { if err != nil { return nil, err } - ec := &EmailEncrypt{ + ec := &FieldEncrypt{ block: block, } return ec, nil } -func (ec *EmailEncrypt) Encrypt(payload string) string { +func (ec *FieldEncrypt) Encrypt(payload string) string { plainText := pkcs5Padding([]byte(payload)) cipherText := make([]byte, len(plainText)) cbc := cipher.NewCBCEncrypter(ec.block, iv) @@ -50,7 +50,7 @@ func (ec *EmailEncrypt) Encrypt(payload string) string { return base64.StdEncoding.EncodeToString(cipherText) } -func (ec *EmailEncrypt) Decrypt(data string) (string, error) { +func (ec *FieldEncrypt) Decrypt(data string) (string, error) { cipherText, err := base64.StdEncoding.DecodeString(data) if err != nil { return "", err diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go index 5fb59a692..efa740921 100644 --- a/management/server/activity/sqlite/crypt_test.go +++ b/management/server/activity/sqlite/crypt_test.go @@ -10,7 +10,7 @@ func TestGenerateKey(t *testing.T) { if err != nil { t.Fatalf("failed to generate key: %s", err) } - ee, err := NewEmailEncrypt(key) + ee, err := NewFieldEncrypt(key) if err != nil { t.Fatalf("failed to init email encryption: %s", err) } @@ -36,7 +36,7 @@ func TestCorruptKey(t *testing.T) { if err != nil { t.Fatalf("failed to generate key: %s", err) } - ee, err := NewEmailEncrypt(key) + ee, err := NewFieldEncrypt(key) if err != nil { t.Fatalf("failed to init email encryption: %s", err) } @@ -51,13 +51,13 @@ func TestCorruptKey(t *testing.T) { t.Fatalf("failed to generate key: %s", err) } - ee, err = NewEmailEncrypt(newKey) + ee, err = NewFieldEncrypt(newKey) if err != nil { t.Fatalf("failed to init email encryption: %s", err) } - res, err := ee.Decrypt(encrypted) - if err == nil || res == testData { + res, _ := ee.Decrypt(encrypted) + if res == testData { t.Fatalf("incorrect decryption, the result is: %s", res) } } diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index 7ff59674d..6af4d4d8d 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -7,7 +7,7 @@ import ( "path/filepath" "time" - _ "github.com/mattn/go-sqlite3" // sqlite driver + _ "github.com/mattn/go-sqlite3" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" @@ -25,16 +25,16 @@ const ( "meta TEXT," + " target_id TEXT);" - creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);` + creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);` - selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta + selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta FROM events LEFT JOIN deleted_users i ON events.initiator_id = i.id LEFT JOIN deleted_users t ON events.target_id = t.id WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;` - selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta + selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta FROM events LEFT JOIN deleted_users i ON events.initiator_id = i.id LEFT JOIN deleted_users t ON events.target_id = t.id @@ -44,13 +44,13 @@ const ( insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " + "VALUES(?, ?, ?, ?, ?, ?)" - insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)` + insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` ) // Store is the implementation of the activity.Store interface backed by SQLite type Store struct { db *sql.DB - emailEncrypt *EmailEncrypt + fieldEncrypt *FieldEncrypt insertStatement *sql.Stmt selectAscStatement *sql.Stmt @@ -66,49 +66,63 @@ func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) { return nil, err } - crypt, err := NewEmailEncrypt(encryptionKey) + crypt, err := NewFieldEncrypt(encryptionKey) if err != nil { + _ = db.Close() return nil, err } _, err = db.Exec(createTableQuery) if err != nil { + _ = db.Close() return nil, err } - _, err = db.Exec(creatTableAccountEmailQuery) + _, err = db.Exec(creatTableDeletedUsersQuery) if err != nil { + _ = db.Close() + return nil, err + } + + err = updateDeletedUsersTable(db) + if err != nil { + _ = db.Close() return nil, err } insertStmt, err := db.Prepare(insertQuery) if err != nil { + _ = db.Close() return nil, err } selectDescStmt, err := db.Prepare(selectDescQuery) if err != nil { + _ = db.Close() return nil, err } selectAscStmt, err := db.Prepare(selectAscQuery) if err != nil { + _ = db.Close() return nil, err } deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) if err != nil { + _ = db.Close() return nil, err } s := &Store{ db: db, - emailEncrypt: crypt, + fieldEncrypt: crypt, insertStatement: insertStmt, selectDescStatement: selectDescStmt, selectAscStatement: selectAscStmt, deleteUserStmt: deleteUserStmt, } + return s, nil } @@ -119,12 +133,14 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { var operation activity.Activity var timestamp time.Time var initiator string + var initiatorName *string var initiatorEmail *string var target string + var targetUserName *string var targetEmail *string var account string var jsonMeta string - err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta) + err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorName, &initiatorEmail, &target, &targetUserName, &targetEmail, &account, &jsonMeta) if err != nil { return nil, err } @@ -137,8 +153,18 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { } } + if targetUserName != nil { + name, err := store.fieldEncrypt.Decrypt(*targetUserName) + if err != nil { + log.Errorf("failed to decrypt username for target id: %s", target) + meta["username"] = "" + } else { + meta["username"] = name + } + } + if targetEmail != nil { - email, err := store.emailEncrypt.Decrypt(*targetEmail) + email, err := store.fieldEncrypt.Decrypt(*targetEmail) if err != nil { log.Errorf("failed to decrypt email address for target id: %s", target) meta["email"] = "" @@ -157,8 +183,17 @@ func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) { Meta: meta, } + if initiatorName != nil { + name, err := store.fieldEncrypt.Decrypt(*initiatorName) + if err != nil { + log.Errorf("failed to decrypt username of initiator: %s", initiator) + } else { + event.InitiatorName = name + } + } + if initiatorEmail != nil { - email, err := store.emailEncrypt.Decrypt(*initiatorEmail) + email, err := store.fieldEncrypt.Decrypt(*initiatorEmail) if err != nil { log.Errorf("failed to decrypt email address of initiator: %s", initiator) } else { @@ -191,7 +226,7 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([ // Save an event in the SQLite events table end encrypt the "email" element in meta map func (store *Store) Save(event *activity.Event) (*activity.Event, error) { var jsonMeta string - meta, err := store.saveDeletedUserEmailInEncrypted(event) + meta, err := store.saveDeletedUserEmailAndNameInEncrypted(event) if err != nil { return nil, err } @@ -219,26 +254,31 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) { return eventCopy, nil } -// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from -// meta map -func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) { +// saveDeletedUserEmailAndNameInEncrypted if the meta contains email and name then store it in encrypted way and delete +// this item from meta map +func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event) (map[string]any, error) { email, ok := event.Meta["email"] if !ok { return event.Meta, nil } - delete(event.Meta, "email") + name, ok := event.Meta["name"] + if !ok { + return event.Meta, nil + } - encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email)) - _, err := store.deleteUserStmt.Exec(event.TargetID, encrypted) + encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) + encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) + _, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName) if err != nil { return nil, err } - if len(event.Meta) == 1 { + if len(event.Meta) == 2 { return nil, nil // nolint } delete(event.Meta, "email") + delete(event.Meta, "name") return event.Meta, nil } @@ -249,3 +289,44 @@ func (store *Store) Close() error { } return nil } + +func updateDeletedUsersTable(db *sql.DB) error { + log.Debugf("check deleted_users table version") + rows, err := db.Query(`PRAGMA table_info(deleted_users);`) + if err != nil { + return err + } + defer rows.Close() + found := false + for rows.Next() { + var ( + cid int + name string + dataType string + notNull int + dfltVal sql.NullString + pk int + ) + err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk) + if err != nil { + return err + } + if name == "name" { + found = true + break + } + } + + err = rows.Err() + if err != nil { + return err + } + + if found { + return nil + } + + log.Debugf("update delted_users table") + _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) + return err +} diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index f2d1e26bf..db2561481 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -922,6 +922,10 @@ components: description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. type: string example: google-oauth2|123456789012345678901 + initiator_name: + description: The name of the initiator of the event. + type: string + example: John Doe initiator_email: description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event. type: string @@ -942,6 +946,7 @@ components: - activity - activity_code - initiator_id + - initiator_name - initiator_email - target_id - meta diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 33c935a68..dec73630d 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -170,6 +170,9 @@ type Event struct { // InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event. InitiatorId string `json:"initiator_id"` + // InitiatorName The name of the initiator of the event. + InitiatorName string `json:"initiator_name"` + // Meta The metadata of the event Meta map[string]string `json:"meta"` diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index cbca44364..a89c206a3 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -50,7 +50,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithInitiatorEmail(events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(events, account.Id, user.Id) if err != nil { util.WriteError(err, w) return @@ -59,8 +59,8 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(w, events) } -func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accountId, userId string) error { - // build email map based on users +func (h *EventsHandler) fillEventsWithUserInfo(events []*api.Event, accountId, userId string) error { + // build email, name maps based on users userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId) if err != nil { log.Errorf("failed to get users from account: %s", err) @@ -68,19 +68,39 @@ func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accoun } emails := make(map[string]string) + names := make(map[string]string) for _, ui := range userInfos { emails[ui.ID] = ui.Email + names[ui.ID] = ui.Name } - // fill event with email of initiator var ok bool for _, event := range events { + // fill initiator if event.InitiatorEmail == "" { event.InitiatorEmail, ok = emails[event.InitiatorId] if !ok { log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId) } } + + if event.InitiatorName == "" { + // here to allowed to be empty because in the first release we did not store the name + event.InitiatorName = names[event.InitiatorId] + } + + // fill target meta + email, ok := emails[event.TargetId] + if !ok { + continue + } + event.Meta["email"] = email + + username, ok := names[event.TargetId] + if !ok { + continue + } + event.Meta["username"] = username } return nil } @@ -95,6 +115,7 @@ func toEventResponse(event *activity.Event) *api.Event { e := &api.Event{ Id: fmt.Sprint(event.ID), InitiatorId: event.InitiatorID, + InitiatorName: event.InitiatorName, InitiatorEmail: event.InitiatorEmail, Activity: event.Activity.Message(), ActivityCode: api.EventActivityCode(event.Activity.StringCode()), diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 44f4919f5..277627310 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -77,6 +77,7 @@ func WriteErrorResponse(errMsg string, httpStatus int, w http.ResponseWriter) { // WriteError converts an error to an JSON error response. // If it is known internal error of type server.Error then it sets the messages from the error, a generic message otherwise func WriteError(err error, w http.ResponseWriter) { + log.Errorf("got a handler error: %s", err.Error()) errStatus, ok := status.FromError(err) httpStatus := http.StatusInternalServerError msg := "internal server error" diff --git a/management/server/idp/idp.go b/management/server/idp/idp.go index 7e1064da1..76eaade74 100644 --- a/management/server/idp/idp.go +++ b/management/server/idp/idp.go @@ -46,10 +46,10 @@ type Config struct { ManagerType string ClientConfig *ClientConfig ExtraConfig ExtraConfig - Auth0ClientCredentials Auth0ClientConfig - AzureClientCredentials AzureClientConfig - KeycloakClientCredentials KeycloakClientConfig - ZitadelClientCredentials ZitadelClientConfig + Auth0ClientCredentials *Auth0ClientConfig + AzureClientCredentials *AzureClientConfig + KeycloakClientCredentials *KeycloakClientConfig + ZitadelClientCredentials *ZitadelClientConfig } // ManagerCredentials interface that authenticates using the credential of each type of idp @@ -105,7 +105,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) case "auth0": auth0ClientConfig := config.Auth0ClientCredentials if config.ClientConfig != nil { - auth0ClientConfig = Auth0ClientConfig{ + auth0ClientConfig = &Auth0ClientConfig{ Audience: config.ExtraConfig["Audience"], AuthIssuer: config.ClientConfig.Issuer, ClientID: config.ClientConfig.ClientID, @@ -114,11 +114,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewAuth0Manager(auth0ClientConfig, appMetrics) + return NewAuth0Manager(*auth0ClientConfig, appMetrics) case "azure": azureClientConfig := config.AzureClientCredentials if config.ClientConfig != nil { - azureClientConfig = AzureClientConfig{ + azureClientConfig = &AzureClientConfig{ ClientID: config.ClientConfig.ClientID, ClientSecret: config.ClientConfig.ClientSecret, GrantType: config.ClientConfig.GrantType, @@ -128,11 +128,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewAzureManager(azureClientConfig, appMetrics) + return NewAzureManager(*azureClientConfig, appMetrics) case "keycloak": keycloakClientConfig := config.KeycloakClientCredentials if config.ClientConfig != nil { - keycloakClientConfig = KeycloakClientConfig{ + keycloakClientConfig = &KeycloakClientConfig{ ClientID: config.ClientConfig.ClientID, ClientSecret: config.ClientConfig.ClientSecret, GrantType: config.ClientConfig.GrantType, @@ -141,11 +141,11 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewKeycloakManager(keycloakClientConfig, appMetrics) + return NewKeycloakManager(*keycloakClientConfig, appMetrics) case "zitadel": zitadelClientConfig := config.ZitadelClientCredentials if config.ClientConfig != nil { - zitadelClientConfig = ZitadelClientConfig{ + zitadelClientConfig = &ZitadelClientConfig{ ClientID: config.ClientConfig.ClientID, ClientSecret: config.ClientConfig.ClientSecret, GrantType: config.ClientConfig.GrantType, @@ -154,7 +154,7 @@ func NewManager(config Config, appMetrics telemetry.AppMetrics) (Manager, error) } } - return NewZitadelManager(zitadelClientConfig, appMetrics) + return NewZitadelManager(*zitadelClientConfig, appMetrics) case "authentik": authentikConfig := AuthentikClientConfig{ Issuer: config.ClientConfig.Issuer, diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 73958a69e..e42b9a506 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -13,8 +13,9 @@ import ( "time" "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server/telemetry" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/telemetry" ) // ZitadelManager zitadel manager client instance. @@ -428,7 +429,7 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe return err } - resource := fmt.Sprintf("users/%s", userID) + resource := fmt.Sprintf("users/%s/metadata/_bulk", userID) _, err = zm.post(resource, string(payload)) if err != nil { return err diff --git a/management/server/user.go b/management/server/user.go index ebebe1e0f..23c6f228d 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -309,6 +309,9 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( // DeleteUser deletes a user from the given account. func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { + if initiatorUserID == targetUserID { + return status.Errorf(status.InvalidArgument, "self deletion is not allowed") + } unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -340,7 +343,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t return err } - targetUserEmail, err := am.getEmailOfTargetUser(account.Id, initiatorUserID, targetUserID) + tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(account.Id, initiatorUserID, targetUserID) if err != nil { log.Errorf("failed to resolve email address: %s", err) return err @@ -352,15 +355,15 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t meta = map[string]any{"name": targetUser.ServiceUserName} eventAction = activity.ServiceUserDeleted } else { - meta = map[string]any{"email": targetUserEmail} + meta = map[string]any{"name": tuName, "email": tuEmail} eventAction = activity.UserDeleted - } am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta) - if !isNil(am.idpManager) { + if !targetUser.IsServiceUser && !isNil(am.idpManager) { err := am.deleteUserFromIDP(targetUserID, accountID) if err != nil { + log.Debugf("failed to delete user from IDP: %s", targetUserID) return err } } @@ -876,18 +879,18 @@ func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID strin return nil } -func (am *DefaultAccountManager) getEmailOfTargetUser(accountId string, initiatorId, targetId string) (string, error) { +func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(accountId, initiatorId, targetId string) (string, string, error) { userInfos, err := am.GetUsersFromAccount(accountId, initiatorId) if err != nil { - return "", err + return "", "", err } for _, ui := range userInfos { if ui.ID == targetId { - return ui.Email, nil + return ui.Email, ui.Name, nil } } - return "", fmt.Errorf("email not found for user: %s", targetId) + return "", "", fmt.Errorf("user info not found for user: %s", targetId) } func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { diff --git a/management/server/user_test.go b/management/server/user_test.go index bd64074b9..1565814b8 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -424,7 +424,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { assert.Nil(t, store.Accounts[mockAccountID].Users[mockServiceUserID]) } -func TestUser_DeleteUser_regularUser(t *testing.T) { +func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -439,6 +439,32 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } err = am.DeleteUser(mockAccountID, mockUserID, mockUserID) + if err == nil { + t.Fatalf("failed to prevent self deletion") + } +} + +func TestUser_DeleteUser_regularUser(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + targetId := "user2" + account.Users[targetId] = &User{ + Id: targetId, + IsServiceUser: true, + ServiceUserName: "user2username", + } + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + eventStore: &activity.InMemoryEventStore{}, + } + + err = am.DeleteUser(mockAccountID, mockUserID, targetId) if err != nil { t.Errorf("unexpected error: %s", err) } diff --git a/util/file.go b/util/file.go index 022841947..0cbfa37ab 100644 --- a/util/file.go +++ b/util/file.go @@ -5,6 +5,8 @@ import ( "io" "os" "path/filepath" + + log "github.com/sirupsen/logrus" ) // WriteJson writes JSON config object to a file creating parent directories if required @@ -54,6 +56,68 @@ func WriteJson(file string, obj interface{}) error { return nil } +// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file +func DirectWriteJson(file string, obj interface{}) error { + + _, _, err := prepareConfigFileDir(file) + if err != nil { + return err + } + + targetFile, err := openOrCreateFile(file) + if err != nil { + return err + } + + defer func() { + err = targetFile.Close() + if err != nil { + log.Errorf("failed to close file %s: %v", file, err) + } + }() + + // make it pretty + bs, err := json.MarshalIndent(obj, "", " ") + if err != nil { + return err + } + + err = targetFile.Truncate(0) + if err != nil { + return err + } + + _, err = targetFile.Write(bs) + if err != nil { + return err + } + + return nil +} + +func openOrCreateFile(file string) (*os.File, error) { + s, err := os.Stat(file) + if err == nil { + return os.OpenFile(file, os.O_WRONLY, s.Mode()) + } + + if !os.IsNotExist(err) { + return nil, err + } + + targetFile, err := os.Create(file) + if err != nil { + return nil, err + } + //no:lint + err = targetFile.Chmod(0640) + if err != nil { + _ = targetFile.Close() + return nil, err + } + return targetFile, nil +} + // ReadJson reads JSON config file and maps to a provided interface func ReadJson(file string, res interface{}) (interface{}, error) {