diff --git a/management/server/account.go b/management/server/account.go index 938376533..4807a0600 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -49,16 +49,16 @@ type AccountManager interface { CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(accountID, executingUserID string, key *UserInfo) (*UserInfo, error) - DeleteUser(accountID, executingUserID string, targetUserID string) error + CreateUser(accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + DeleteUser(accountID, initiatorUserID string, targetUserID string) error ListSetupKeys(accountID, userID string) ([]*SetupKey, error) - SaveUser(accountID, userID string, update *User) (*UserInfo, error) + SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) MarkPATUsed(tokenID string) error - IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) + GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) GetPeers(accountID, userID string) ([]*Peer, error) @@ -69,10 +69,10 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) - CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) - DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error - GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) + CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error + GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) @@ -179,6 +179,7 @@ type UserInfo struct { AutoGroups []string `json:"auto_groups"` Status string `json:"-"` IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` } // getRoutesToSync returns the enabled routes for the peer ID and the routes diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index a3396a2f8..03ec709e5 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -91,6 +91,10 @@ const ( ServiceUserCreated // ServiceUserDeleted indicates that a user deleted a service user ServiceUserDeleted + // UserBlocked indicates that a user blocked another user + UserBlocked + // UserUnblocked indicates that a user unblocked another user + UserUnblocked ) const ( @@ -184,6 +188,10 @@ const ( ServiceUserCreatedMessage string = "Service user created" // ServiceUserDeletedMessage is a human-readable text message of the ServiceUserDeleted activity ServiceUserDeletedMessage string = "Service user deleted" + // UserBlockedMessage is a human-readable text message of the UserBlocked activity + UserBlockedMessage string = "User blocked" + // UserUnblockedMessage is a human-readable text message of the UserUnblocked activity + UserUnblockedMessage string = "User unblocked" ) // Activity that triggered an Event @@ -282,6 +290,10 @@ func (a Activity) Message() string { return ServiceUserCreatedMessage case ServiceUserDeleted: return ServiceUserDeletedMessage + case UserBlocked: + return UserBlockedMessage + case UserUnblocked: + return UserUnblockedMessage default: return "UNKNOWN_ACTIVITY" } @@ -300,6 +312,10 @@ func (a Activity) StringCode() string { return "user.join" case UserInvited: return "user.invite" + case UserBlocked: + return "user.block" + case UserUnblocked: + return "user.unblock" case AccountCreated: return "account.create" case RuleAdded: diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 449000cb5..45866c877 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -65,7 +65,7 @@ components: status: description: User's status type: string - enum: [ "active","invited","disabled" ] + enum: [ "active","invited","blocked" ] auto_groups: description: Groups to auto-assign to peers registered by this user type: array @@ -79,6 +79,9 @@ components: description: Is true if this user is a service user type: boolean readOnly: true + is_blocked: + description: Is true if this user is blocked. Blocked users can't use the system + type: boolean required: - id - email @@ -86,6 +89,7 @@ components: - role - auto_groups - status + - is_blocked UserRequest: type: object properties: @@ -97,9 +101,13 @@ components: type: array items: type: string + is_blocked: + description: If set to true then user is blocked and can't use the system + type: boolean required: - role - auto_groups + - is_blocked UserCreateRequest: type: object properties: @@ -645,7 +653,7 @@ components: description: The string code of the activity that occurred during the event type: string enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete", - "user.role.update", + "user.role.update", "user.block", "user.unblock", "setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse", "setupkey.group.delete", "setupkey.group.add", "rule.add", "rule.delete", "rule.update", diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index c1da0cf3d..cb1b40c60 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -46,6 +46,7 @@ const ( EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add" EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke" EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update" + EventActivityCodeUserBlock EventActivityCode = "user.block" EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add" EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete" EventActivityCodeUserInvite EventActivityCode = "user.invite" @@ -53,6 +54,7 @@ const ( EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add" EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update" + EventActivityCodeUserUnblock EventActivityCode = "user.unblock" ) // Defines values for NameserverNsType. @@ -68,9 +70,9 @@ const ( // Defines values for UserStatus. const ( - UserStatusActive UserStatus = "active" - UserStatusDisabled UserStatus = "disabled" - UserStatusInvited UserStatus = "invited" + UserStatusActive UserStatus = "active" + UserStatusBlocked UserStatus = "blocked" + UserStatusInvited UserStatus = "invited" ) // Account defines model for Account. @@ -552,6 +554,9 @@ type User struct { // Id User ID Id string `json:"id"` + // IsBlocked Is true if this user is blocked. Blocked users can't use the system + IsBlocked bool `json:"is_blocked"` + // IsCurrent Is true if authenticated user is the same as this user IsCurrent *bool `json:"is_current,omitempty"` @@ -594,6 +599,9 @@ type UserRequest struct { // AutoGroups Groups to auto-assign to peers registered by this user AutoGroups []string `json:"auto_groups"` + // IsBlocked If set to true then user is blocked and can't use the system + IsBlocked bool `json:"is_blocked"` + // Role User's NetBird account role Role string `json:"role"` } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 5fa0aea5b..79b8e37b3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -43,7 +43,7 @@ func APIHandler(accountManager s.AccountManager, jwtValidator jwtclaims.JWTValid acMiddleware := middleware.NewAccessControl( authCfg.Audience, authCfg.UserIDClaim, - accountManager.IsUserAdmin) + accountManager.GetUser) rootRouter := mux.NewRouter() metricsMiddleware := appMetrics.HTTPMiddleware() diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 5f8389dfa..fc362ea80 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -6,28 +6,30 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) +// GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims +type GetUser func(claims jwtclaims.AuthorizationClaims) (*server.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { - isUserAdmin IsUserAdminFunc claimsExtract jwtclaims.ClaimsExtractor + getUser GetUser } // NewAccessControl instance constructor -func NewAccessControl(audience, userIDClaim string, isUserAdmin IsUserAdminFunc) *AccessControl { +func NewAccessControl(audience, userIDClaim string, getUser GetUser) *AccessControl { return &AccessControl{ - isUserAdmin: isUserAdmin, claimsExtract: *jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(audience), jwtclaims.WithUserIDClaim(userIDClaim), ), + getUser: getUser, } } @@ -37,23 +39,29 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := a.claimsExtract.FromRequestContext(r) - ok, err := a.isUserAdmin(claims) + user, err := a.getUser(claims) if err != nil { util.WriteError(status.Errorf(status.Unauthorized, "invalid JWT"), w) return } - if !ok { + + if user.IsBlocked() { + util.WriteError(status.Errorf(status.PermissionDenied, "the user has no access to the API or is blocked"), w) + return + } + + if !user.IsAdmin() { switch r.Method { case http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodPut: ok, err := regexp.MatchString(`^.*/api/users/.*/tokens.*$`, r.URL.Path) if err != nil { - log.Debugf("Regex failed") + log.Debugf("regex failed") util.WriteError(status.Errorf(status.Internal, ""), w) return } if ok { - log.Debugf("Valid Path") + log.Debugf("valid Path") h.ServeHTTP(w, r) return } diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 6a1ea8312..45fda0a55 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -63,7 +63,7 @@ var testAccount = &server.Account{ func initPATTestData() *PATHandler { return &PATHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -79,7 +79,7 @@ func initPATTestData() *PATHandler { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { return testAccount, testAccount.Users[existingUserID], nil }, - DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error { + DeletePATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { return status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -91,7 +91,7 @@ func initPATTestData() *PATHandler { } return nil }, - GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -103,7 +103,7 @@ func initPATTestData() *PATHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 475035959..4a9a4a423 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -61,6 +61,11 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } + if req.AutoGroups == nil { + util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w) + return + } + userRole := server.StrRoleToUserRole(req.Role) if userRole == server.UserRoleUnknown { util.WriteError(status.Errorf(status.InvalidArgument, "invalid user role"), w) @@ -71,7 +76,9 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { Id: userID, Role: userRole, AutoGroups: req.AutoGroups, + Blocked: req.IsBlocked, }) + if err != nil { util.WriteError(err, w) return @@ -214,7 +221,11 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { case "invited": userStatus = api.UserStatusInvited default: - userStatus = api.UserStatusDisabled + userStatus = api.UserStatusBlocked + } + + if user.IsBlocked { + userStatus = api.UserStatusBlocked } isCurrent := user.ID == currenUserID @@ -227,5 +238,6 @@ func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { Status: userStatus, IsCurrent: &isCurrent, IsServiceUser: &user.IsServiceUser, + IsBlocked: user.IsBlocked, } } diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index 1dc1453bd..c6d7dd4c6 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -3,6 +3,7 @@ package http import ( "bytes" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -31,16 +32,19 @@ var usersTestAccount = &server.Account{ Id: existingUserID, Role: "admin", IsServiceUser: false, + AutoGroups: []string{"group_1"}, }, regularUserID: { Id: regularUserID, Role: "user", IsServiceUser: false, + AutoGroups: []string{"group_1"}, }, serviceUserID: { Id: serviceUserID, Role: "user", IsServiceUser: true, + AutoGroups: []string{"group_1"}, }, }, } @@ -70,7 +74,7 @@ func initUsersTestData() *UsersHandler { } return key, nil }, - DeleteUserFunc: func(accountID string, executingUserID string, targetUserID string) error { + DeleteUserFunc: func(accountID string, initiatorUserID string, targetUserID string) error { if targetUserID == notFoundUserID { return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID) } @@ -79,6 +83,21 @@ func initUsersTestData() *UsersHandler { } return nil }, + SaveUserFunc: func(accountID, userID string, update *server.User) (*server.UserInfo, error) { + if update.Id == notFoundUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) + } + + if userID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) + } + + info, err := update.Copy().ToUserInfo(nil) + if err != nil { + return nil, err + } + return info, nil + }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { @@ -145,6 +164,122 @@ func TestGetUsers(t *testing.T) { } } +func TestUpdateUser(t *testing.T) { + tt := []struct { + name string + expectedStatusCode int + requestType string + requestPath string + requestBody io.Reader + expectedUserID string + expectedRole string + expectedStatus string + expectedBlocked bool + expectedIsServiceUser bool + expectedGroups []string + }{ + { + name: "Update_Block_User", + requestType: http.MethodPut, + requestPath: "/api/users/" + regularUserID, + expectedStatusCode: http.StatusOK, + expectedUserID: regularUserID, + expectedBlocked: true, + expectedRole: "user", + expectedStatus: "blocked", + expectedGroups: []string{"group_1"}, + requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"), + }, + { + name: "Update_Change_Role_To_Admin", + requestType: http.MethodPut, + requestPath: "/api/users/" + regularUserID, + expectedStatusCode: http.StatusOK, + expectedUserID: regularUserID, + expectedBlocked: false, + expectedRole: "admin", + expectedStatus: "blocked", + expectedGroups: []string{"group_1"}, + requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"), + }, + { + name: "Update_Groups", + requestType: http.MethodPut, + requestPath: "/api/users/" + regularUserID, + expectedStatusCode: http.StatusOK, + expectedUserID: regularUserID, + expectedBlocked: false, + expectedRole: "admin", + expectedStatus: "blocked", + expectedGroups: []string{"group_2", "group_3"}, + requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"), + }, + { + name: "Should_Fail_Because_AutoGroups_Is_Absent", + requestType: http.MethodPut, + requestPath: "/api/users/" + regularUserID, + expectedStatusCode: http.StatusBadRequest, + expectedUserID: regularUserID, + expectedBlocked: false, + expectedRole: "admin", + expectedStatus: "blocked", + expectedGroups: []string{"group_2", "group_3"}, + requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"), + }, + } + + userHandler := initUsersTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + if status := recorder.Code; status != tc.expectedStatusCode { + t.Fatalf("handler returned wrong status code: got %v want %v", + status, http.StatusOK) + } + + if tc.expectedStatusCode == 200 { + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + respBody := &api.User{} + err = json.Unmarshal(content, &respBody) + if err != nil { + t.Fatalf("response content is not in correct json format; %v", err) + } + + assert.Equal(t, tc.expectedUserID, respBody.Id) + assert.Equal(t, tc.expectedRole, respBody.Role) + assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser) + assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked) + assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups)) + + for _, expectedGroup := range tc.expectedGroups { + exists := false + for _, actualGroup := range respBody.AutoGroups { + if expectedGroup == actualGroup { + exists = true + } + } + assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup)) + } + } + }) + } +} + func TestCreateUser(t *testing.T) { name := "name" email := "email" diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 5d4eb709f..33f2a81f3 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -19,7 +19,7 @@ type MockAccountManager struct { expiresIn time.Duration, autoGroups []string, usageLimit int, userID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) - IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) + GetUserFunc func(claims jwtclaims.AuthorizationClaims) (*server.User, error) AccountExistsFunc func(accountId string) (*bool, error) GetPeerByKeyFunc func(peerKey string) (*server.Peer, error) GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) @@ -60,11 +60,11 @@ type MockAccountManager struct { SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - DeleteUserFunc func(accountID string, executingUserID string, targetUserID string) error - CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) - DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error - GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + DeleteUserFunc func(accountID string, initiatorUserID string, targetUserID string) error + CreatePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error @@ -190,33 +190,33 @@ func (am *MockAccountManager) MarkPATUsed(pat string) error { } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { - return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn) + return am.CreatePATFunc(accountID, initiatorUserID, targetUserID, name, expiresIn) } return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") } // DeletePAT mock implementation of DeletePAT from server.AccountManager interface -func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { +func (am *MockAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if am.DeletePATFunc != nil { - return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID) + return am.DeletePATFunc(accountID, initiatorUserID, targetUserID, tokenID) } return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { if am.GetPATFunc != nil { - return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID) + return am.GetPATFunc(accountID, initiatorUserID, targetUserID, tokenID) } return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") } // GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { - return am.GetAllPATsFunc(accountID, executingUserID, targetUserID) + return am.GetAllPATsFunc(accountID, initiatorUserID, targetUserID) } return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") } @@ -385,12 +385,12 @@ func (am *MockAccountManager) UpdatePeerMeta(peerID string, meta server.PeerSyst return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented") } -// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface -func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { - if am.IsUserAdminFunc != nil { - return am.IsUserAdminFunc(claims) +// GetUser mock implementation of GetUser from server.AccountManager interface +func (am *MockAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*server.User, error) { + if am.GetUserFunc != nil { + return am.GetUserFunc(claims) } - return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method IsUserGetUserAdmin is not implemented") } // UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager @@ -493,9 +493,9 @@ func (am *MockAccountManager) SaveUser(accountID, userID string, user *server.Us } // DeleteUser mocks DeleteUser of the AccountManager interface -func (am *MockAccountManager) DeleteUser(accountID string, executingUserID string, targetUserID string) error { +func (am *MockAccountManager) DeleteUser(accountID string, initiatorUserID string, targetUserID string) error { if am.DeleteUserFunc != nil { - return am.DeleteUserFunc(accountID, executingUserID, targetUserID) + return am.DeleteUserFunc(accountID, initiatorUserID, targetUserID) } return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index 0c18f1124..1e856c724 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -605,6 +605,11 @@ func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, er return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") } + err = checkIfPeerOwnerIsBlocked(peer, account) + if err != nil { + return nil, nil, err + } + if peerLoginExpired(peer, account) { return nil, nil, status.Errorf(status.PermissionDenied, "peer login has expired, please log in once more") } @@ -644,6 +649,11 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") } + err = checkIfPeerOwnerIsBlocked(peer, account) + if err != nil { + return nil, nil, err + } + updateRemotePeers := false if peerLoginExpired(peer, account) { err = checkAuth(login.UserID, peer) @@ -676,6 +686,19 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, *NetworkMap, return peer, account.GetPeerNetworkMap(peer.ID, am.dnsDomain), nil } +func checkIfPeerOwnerIsBlocked(peer *Peer, account *Account) error { + if peer.AddedWithSSOLogin() { + user, err := account.FindUser(peer.UserID) + if err != nil { + return status.Errorf(status.PermissionDenied, "user doesn't exist") + } + if user.IsBlocked() { + return status.Errorf(status.PermissionDenied, "user is blocked") + } + } + return nil +} + func checkAuth(loginUserID string, peer *Peer) error { if loginUserID == "" { // absence of a user ID indicates that JWT wasn't provided. diff --git a/management/server/user.go b/management/server/user.go index ea7dba2d2..8af8b01bd 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -51,15 +51,22 @@ type User struct { // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user AutoGroups []string PATs map[string]*PersonalAccessToken + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool } -// IsAdmin returns true if user is an admin, false otherwise +// IsBlocked returns true if the user is blocked, false otherwise +func (u *User) IsBlocked() bool { + return u.Blocked +} + +// IsAdmin returns true if the user is an admin, false otherwise func (u *User) IsAdmin() bool { return u.Role == UserRoleAdmin } -// toUserInfo converts a User object to a UserInfo object. -func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { +// ToUserInfo converts a User object to a UserInfo object. +func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { autoGroups := u.AutoGroups if autoGroups == nil { autoGroups = []string{} @@ -74,6 +81,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { AutoGroups: u.AutoGroups, Status: string(UserStatusActive), IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, }, nil } if userData.ID != u.Id { @@ -93,6 +101,7 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { AutoGroups: autoGroups, Status: string(userStatus), IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, }, nil } @@ -113,6 +122,7 @@ func (u *User) Copy() *User { IsServiceUser: u.IsServiceUser, ServiceUserName: u.ServiceUserName, PATs: pats, + Blocked: u.Blocked, } } @@ -138,7 +148,7 @@ func NewAdminUser(id string) *User { } // createServiceUser creates a new service user under the given account. -func (am *DefaultAccountManager) createServiceUser(accountID string, executingUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) { +func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, autoGroups []string) (*UserInfo, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -147,7 +157,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID) } - executingUser := account.Users[executingUserID] + executingUser := account.Users[initiatorUserID] if executingUser == nil { return nil, status.Errorf(status.NotFound, "user not found") } @@ -166,7 +176,7 @@ func (am *DefaultAccountManager) createServiceUser(accountID string, executingUs } meta := map[string]any{"name": newUser.ServiceUserName} - am.storeEvent(executingUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) + am.storeEvent(initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) return &UserInfo{ ID: newUser.Id, @@ -212,7 +222,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite } if user != nil { - return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account") + return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } users, err := am.idpManager.GetUserByEmail(invite.Email) @@ -221,7 +231,7 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite } if len(users) > 0 { - return nil, status.Errorf(status.UserAlreadyExists, "user has an existing account") + return nil, status.Errorf(status.UserAlreadyExists, "can't invite a user with an existing NetBird account") } idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID) @@ -249,12 +259,27 @@ func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite am.storeEvent(userID, newUser.Id, accountID, activity.UserInvited, nil) - return newUser.toUserInfo(idpUser) + return newUser.ToUserInfo(idpUser) } +// GetUser looks up a user by provided authorization claims. +// It will also create an account if didn't exist for this user before. +func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) { + account, _, err := am.GetAccountFromToken(claims) + if err != nil { + return nil, fmt.Errorf("failed to get account with token claims %v", err) + } + + user, ok := account.Users[claims.UserId] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found") + } + return user, nil +} + // DeleteUser deletes a user from the given account. -func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, targetUserID string) error { +func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, targetUserID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -268,7 +293,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t return status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[executingUserID] + executingUser := account.Users[initiatorUserID] if executingUser == nil { return status.Errorf(status.NotFound, "user not found") } @@ -281,7 +306,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t } meta := map[string]any{"name": targetUser.ServiceUserName} - am.storeEvent(executingUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) + am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta) delete(account.Users, targetUserID) @@ -294,7 +319,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, executingUserID string, t } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { +func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -316,12 +341,12 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str return nil, status.Errorf(status.NotFound, "targetUser not found") } - executingUser := account.Users[executingUserID] + executingUser := account.Users[initiatorUserID] if targetUser == nil { return nil, status.Errorf(status.NotFound, "user not found") } - if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { + if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } @@ -338,13 +363,13 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID str } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} - am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) + am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenCreated, meta) return pat, nil } // DeletePAT deletes a specific PAT from a user -func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { +func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -358,12 +383,12 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str return status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[executingUserID] + executingUser := account.Users[initiatorUserID] if targetUser == nil { return status.Errorf(status.NotFound, "user not found") } - if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { + if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") } @@ -382,7 +407,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} - am.storeEvent(executingUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) + am.storeEvent(initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) delete(targetUser.PATs, tokenID) @@ -394,7 +419,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID str } // GetPAT returns a specific PAT from a user -func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { +func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -408,12 +433,12 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string return nil, status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[executingUserID] + executingUser := account.Users[initiatorUserID] if targetUser == nil { return nil, status.Errorf(status.NotFound, "user not found") } - if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { + if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser") } @@ -426,7 +451,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) { +func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -440,12 +465,12 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st return nil, status.Errorf(status.NotFound, "user not found") } - executingUser := account.Users[executingUserID] + executingUser := account.Users[initiatorUserID] if targetUser == nil { return nil, status.Errorf(status.NotFound, "user not found") } - if !(executingUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { + if !(initiatorUserID == targetUserID || (executingUser.IsAdmin() && targetUser.IsServiceUser)) { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } @@ -457,9 +482,9 @@ func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID st return pats, nil } -// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. -// Only User.AutoGroups field is allowed to be updated for now. -func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) { +// SaveUser saves updates to the given user. If the user doesn't exit it will throw status.NotFound error. +// Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. +func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, update *User) (*UserInfo, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() @@ -472,56 +497,102 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User return nil, err } + initiatorUser, err := account.FindUser(initiatorUserID) + if err != nil { + return nil, err + } + + if !initiatorUser.IsAdmin() || initiatorUser.IsBlocked() { + return nil, status.Errorf(status.PermissionDenied, "only admins are authorized to perform user update operations") + } + + oldUser := account.Users[update.Id] + if oldUser == nil { + return nil, status.Errorf(status.NotFound, "user to update doesn't exist") + } + + if initiatorUser.IsAdmin() && initiatorUserID == update.Id && oldUser.Blocked != update.Blocked { + return nil, status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") + } + + if initiatorUser.IsAdmin() && initiatorUserID == update.Id && update.Role != UserRoleAdmin { + return nil, status.Errorf(status.PermissionDenied, "admins can't change their role") + } + + // only auto groups, revoked status, and name can be updated for now + newUser := oldUser.Copy() + newUser.Role = update.Role + newUser.Blocked = update.Blocked + for _, newGroupID := range update.AutoGroups { if _, ok := account.Groups[newGroupID]; !ok { return nil, status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", newGroupID, update.Id) } } - - oldUser := account.Users[update.Id] - if oldUser == nil { - return nil, status.Errorf(status.NotFound, "update not found") - } - - // only auto groups, revoked status, and name can be updated for now - newUser := oldUser.Copy() newUser.AutoGroups = update.AutoGroups - newUser.Role = update.Role account.Users[newUser.Id] = newUser + if !oldUser.IsBlocked() && update.IsBlocked() { + // expire peers that belong to the user who's getting blocked + blockedPeers, err := account.FindUserPeers(update.Id) + if err != nil { + return nil, err + } + var peerIDs []string + for _, peer := range blockedPeers { + peerIDs = append(peerIDs, peer.ID) + peer.MarkLoginExpired(true) + account.UpdatePeer(peer) + err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status) + if err != nil { + log.Errorf("failed saving peer status while expiring peer %s", peer.ID) + return nil, err + } + } + am.peersUpdateManager.CloseChannels(peerIDs) + err = am.updateAccountPeers(account) + if err != nil { + log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID) + return nil, err + + } + } + if err = am.Store.SaveAccount(account); err != nil { return nil, err } defer func() { + // store activity logs if oldUser.Role != newUser.Role { - am.storeEvent(userID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) } - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - am.storeEvent(userID, oldUser.Id, accountID, activity.GroupRemovedFromUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - } else { - log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + if update.AutoGroups != nil { + removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) + for _, g := range removedGroups { + group := account.GetGroup(g) + if group != nil { + am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + } else { + log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + } + } - } - - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - am.storeEvent(userID, oldUser.Id, accountID, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - } else { - log.Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + for _, g := range addedGroups { + group := account.GetGroup(g) + if group != nil { + am.storeEvent(initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + } } } + }() if !isNil(am.idpManager) && !newUser.IsServiceUser { @@ -532,9 +603,9 @@ func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User if userData == nil { return nil, status.Errorf(status.NotFound, "user %s not found in the IdP", newUser.Id) } - return newUser.toUserInfo(userData) + return newUser.ToUserInfo(userData) } - return newUser.toUserInfo(nil) + return newUser.ToUserInfo(nil) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist @@ -574,21 +645,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) return account, nil } -// IsUserAdmin looks up a user by his ID and returns true if he is an admin -func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { - account, _, err := am.GetAccountFromToken(claims) - if err != nil { - return false, fmt.Errorf("get account: %v", err) - } - - user, ok := account.Users[claims.UserId] - if !ok { - return false, status.Errorf(status.NotFound, "user not found") - } - - return user.Role == UserRoleAdmin, nil -} - // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { @@ -625,7 +681,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( // if user is not an admin then show only current user and do not show other users continue } - info, err := accountUser.toUserInfo(nil) + info, err := accountUser.ToUserInfo(nil) if err != nil { return nil, err } @@ -642,7 +698,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ( var info *UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { - info, err = localUser.toUserInfo(queriedUser) + info, err = localUser.ToUserInfo(queriedUser) if err != nil { return nil, err } diff --git a/management/server/user_test.go b/management/server/user_test.go index 504d231e5..d6226f76d 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -265,6 +266,7 @@ func TestUser_Copy(t *testing.T) { LastUsed: time.Now(), }, }, + Blocked: false, } err := validateStruct(user) @@ -288,7 +290,7 @@ func validateStruct(s interface{}) (err error) { field := structVal.Field(i) fieldName := structType.Field(i).Name - isSet := field.IsValid() && !field.IsZero() + isSet := field.IsValid() && (!field.IsZero() || field.Type().String() == "bool") if !isSet { err = fmt.Errorf("%v%s in not set; ", err, fieldName) @@ -440,7 +442,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { assert.Errorf(t, err, "Regular users can not be deleted (yet)") } -func TestUser_IsUserAdmin_ForAdmin(t *testing.T) { +func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -458,42 +460,23 @@ func TestUser_IsUserAdmin_ForAdmin(t *testing.T) { UserId: mockUserID, } - ok, err := am.IsUserAdmin(claims) + user, err := am.GetUser(claims) if err != nil { t.Fatalf("Error when checking user role: %s", err) } - assert.True(t, ok) + assert.Equal(t, mockUserID, user.Id) + assert.True(t, user.IsAdmin()) + assert.False(t, user.IsBlocked()) } -func TestUser_IsUserAdmin_ForUser(t *testing.T) { - store := newStore(t) - account := newAccountWithId(mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - Role: "user", - } +func TestUser_IsAdmin(t *testing.T) { - err := store.SaveAccount(account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + user := NewAdminUser(mockUserID) + assert.True(t, user.IsAdmin()) - am := DefaultAccountManager{ - Store: store, - eventStore: &activity.InMemoryEventStore{}, - } - - claims := jwtclaims.AuthorizationClaims{ - UserId: mockUserID, - } - - ok, err := am.IsUserAdmin(claims) - if err != nil { - t.Fatalf("Error when checking user role: %s", err) - } - - assert.False(t, ok) + user = NewRegularUser(mockUserID) + assert.False(t, user.IsAdmin()) } func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { @@ -550,3 +533,103 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { assert.Equal(t, 1, len(users)) assert.Equal(t, mockServiceUserID, users[0].ID) } + +func TestDefaultAccountManager_SaveUser(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + regularUserID := "regularUser" + + tt := []struct { + name string + adminInitiator bool + update *User + expectedErr bool + }{ + { + name: "Should_Fail_To_Update_Admin_Role", + expectedErr: true, + adminInitiator: true, + update: &User{ + Id: userID, + Role: UserRoleUser, + Blocked: false, + }, + }, { + name: "Should_Fail_When_Admin_Blocks_Themselves", + expectedErr: true, + adminInitiator: true, + update: &User{ + Id: userID, + Role: UserRoleAdmin, + Blocked: true, + }, + }, + { + name: "Should_Fail_To_Update_Non_Existing_User", + expectedErr: true, + adminInitiator: true, + update: &User{ + Id: userID, + Role: UserRoleAdmin, + Blocked: true, + }, + }, + { + name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin", + expectedErr: true, + adminInitiator: false, + update: &User{ + Id: userID, + Role: UserRoleAdmin, + Blocked: true, + }, + }, + { + name: "Should_Update_User", + expectedErr: false, + adminInitiator: true, + update: &User{ + Id: regularUserID, + Role: UserRoleAdmin, + Blocked: true, + }, + }, + } + + for _, tc := range tt { + + // create an account and an admin user + account, err := manager.GetOrCreateAccountByUser(userID, "netbird.io") + if err != nil { + t.Fatal(err) + } + + // create a regular user + account.Users[regularUserID] = NewRegularUser(regularUserID) + err = manager.Store.SaveAccount(account) + if err != nil { + t.Fatal(err) + } + + initiatorID := userID + if !tc.adminInitiator { + initiatorID = regularUserID + } + + updated, err := manager.SaveUser(account.Id, initiatorID, tc.update) + if tc.expectedErr { + require.Errorf(t, err, "expecting SaveUser to throw an error") + } else { + require.NoError(t, err, "expecting SaveUser not to throw an error") + assert.NotNil(t, updated) + + assert.Equal(t, string(tc.update.Role), updated.Role) + assert.Equal(t, tc.update.IsBlocked(), updated.IsBlocked) + } + } + +}