mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 02:06:39 +00:00
use api wrapper for permissions management
This commit is contained in:
@@ -10,9 +10,12 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -55,14 +58,14 @@ type invitesHandler struct {
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.listInvites(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.listInvites(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.createInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.createInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.regenerateInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.regenerateInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.deleteInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.deleteInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
|
||||
@@ -7,8 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -19,12 +22,12 @@ type patHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
tokenHandler := newPATsHandler(accountManager)
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -24,6 +24,18 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
existingAccountID = "existingAccountID"
|
||||
notFoundAccountID = "notFoundAccountID"
|
||||
@@ -181,10 +193,10 @@ func TestTokenHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.getAllTokens)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.getToken)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.createToken)).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.deleteToken)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -9,13 +9,15 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
@@ -23,18 +25,18 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router, permissionsManager)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// updateUser is a PUT requests to update User data
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -89,7 +84,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@@ -102,23 +97,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -126,7 +114,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -136,21 +124,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -171,7 +152,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@@ -183,25 +164,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -215,7 +189,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -226,7 +200,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,19 +209,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// inviteUser resend invitations to users who haven't activated their accounts,
|
||||
// prior to the expiration period.
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -255,7 +222,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -264,19 +231,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
|
||||
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -356,7 +317,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
}
|
||||
|
||||
// approveUser is a POST request to approve a user that is pending approval
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -369,11 +330,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -385,7 +341,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// rejectUser is a DELETE request to reject a user that is pending approval
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -398,12 +354,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -421,7 +372,7 @@ type passwordChangeRequest struct {
|
||||
// changePassword is a PUT request to change user's password.
|
||||
// Only available when embedded IDP is enabled.
|
||||
// Users can only change their own password.
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -434,19 +385,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req passwordChangeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -232,7 +232,11 @@ func TestGetUsers(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getAllUsers(recorder, req, userAuth)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -343,7 +347,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
router.HandleFunc("/api/users/{userId}", wrapHandler(userHandler.updateUser)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -439,7 +443,11 @@ func TestCreateUser(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.createUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.createUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -490,7 +498,11 @@ func TestInviteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.inviteUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.inviteUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -549,7 +561,11 @@ func TestDeleteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.deleteUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.deleteUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -682,7 +698,11 @@ func TestCurrentUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.getCurrentUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getCurrentUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -779,7 +799,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
|
||||
router.HandleFunc("/users/{userId}/approve", wrapHandler(handler.approveUser)).Methods("POST")
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user