mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
[management] refactor auth (#3296)
This commit is contained in:
@@ -9,39 +9,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, authCfg, router)
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
@@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
@@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
@@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
|
||||
Reference in New Issue
Block a user