use api wrapper for permissions management

This commit is contained in:
pascal
2026-02-24 00:39:54 +01:00
parent 9d123ec059
commit 32730af33f
43 changed files with 867 additions and 1320 deletions

View File

@@ -10,9 +10,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -55,14 +58,14 @@ type invitesHandler struct {
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.listInvites(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.createInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.regenerateInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.deleteInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
})

View File

@@ -7,8 +7,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -19,12 +22,12 @@ type patHandler struct {
accountManager account.Manager
}
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
}
// newPATsHandler creates a new patHandler HTTP handler
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
}
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
}
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
}
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
}
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -24,6 +24,18 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
existingAccountID = "existingAccountID"
notFoundAccountID = "notFoundAccountID"
@@ -181,10 +193,10 @@ func TestTokenHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.getAllTokens)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.getToken)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.createToken)).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.deleteToken)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -9,13 +9,15 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
// handler is a handler that returns users of the account
@@ -23,18 +25,18 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router, permissionsManager)
}
// newHandler creates a new UsersHandler HTTP handler
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
}
// updateUser is a PUT requests to update User data
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -89,7 +84,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
@@ -102,23 +97,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
}
// deleteUser is a DELETE request to delete a user
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -126,7 +114,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -136,21 +124,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
}
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -171,7 +152,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -183,25 +164,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
}
// getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -215,7 +189,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, userID))
users = append(users, toUserResponse(d, userAuth.UserId))
continue
}
@@ -226,7 +200,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, userID))
users = append(users, toUserResponse(d, userAuth.UserId))
}
}
@@ -235,19 +209,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
// inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period.
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -255,7 +222,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -264,19 +231,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -356,7 +317,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
}
// approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -369,11 +330,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -385,7 +341,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
}
// rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -398,12 +354,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -421,7 +372,7 @@ type passwordChangeRequest struct {
// changePassword is a PUT request to change user's password.
// Only available when embedded IDP is enabled.
// Users can only change their own password.
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -434,19 +385,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req passwordChangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -232,7 +232,11 @@ func TestGetUsers(t *testing.T) {
AccountId: existingAccountID,
})
userHandler.getAllUsers(recorder, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getAllUsers(recorder, req, userAuth)
res := recorder.Result()
defer res.Body.Close()
@@ -343,7 +347,7 @@ func TestUpdateUser(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
router.HandleFunc("/api/users/{userId}", wrapHandler(userHandler.updateUser)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -439,7 +443,11 @@ func TestCreateUser(t *testing.T) {
AccountId: existingAccountID,
})
userHandler.createUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.createUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -490,7 +498,11 @@ func TestInviteUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.inviteUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -549,7 +561,11 @@ func TestDeleteUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.deleteUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -682,7 +698,11 @@ func TestCurrentUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.getCurrentUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getCurrentUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -779,7 +799,7 @@ func TestApproveUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
router.HandleFunc("/users/{userId}/approve", wrapHandler(handler.approveUser)).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)