mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management] refactor auth (#3296)
This commit is contained in:
@@ -7,22 +7,20 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
// patHandler is the nameserver group handler of the account
|
||||
type patHandler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
tokenHandler := newPATsHandler(accountManager, authCfg)
|
||||
func addUsersTokensEndpoint(accountManager server.AccountManager, router *mux.Router) {
|
||||
tokenHandler := newPATsHandler(accountManager)
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
@@ -30,25 +28,21 @@ func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg config
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
|
||||
func newPATsHandler(accountManager server.AccountManager) *patHandler {
|
||||
return &patHandler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
@@ -72,13 +66,13 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -103,13 +97,13 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -135,13 +129,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
|
||||
@@ -12,11 +12,12 @@ import (
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -77,10 +78,6 @@ func initPATTestData() *patHandler {
|
||||
PersonalAccessToken: types.PersonalAccessToken{},
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
@@ -115,15 +112,6 @@ func initPATTestData() *patHandler {
|
||||
return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +173,11 @@ func TestTokenHandlers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
|
||||
@@ -9,39 +9,33 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
accountManager server.AccountManager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager, authCfg)
|
||||
func AddEndpoints(accountManager server.AccountManager, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, authCfg, router)
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
func newHandler(accountManager server.AccountManager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,13 +46,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -103,7 +97,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
@@ -113,13 +107,13 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -143,12 +137,12 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
@@ -184,7 +178,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
@@ -195,13 +189,13 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -216,7 +210,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -227,7 +221,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,12 +236,12 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
|
||||
@@ -13,8 +13,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -64,9 +64,6 @@ var usersTestAccount = &types.Account{
|
||||
func initUsersTestData() *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return usersTestAccount.Id, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
@@ -127,15 +124,6 @@ func initUsersTestData() *handler {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +146,11 @@ func TestGetUsers(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
|
||||
@@ -263,6 +256,11 @@ func TestUpdateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
@@ -355,6 +353,11 @@ func TestCreateUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.createUser(rr, req)
|
||||
|
||||
@@ -399,6 +402,12 @@ func TestInviteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.inviteUser(rr, req)
|
||||
@@ -452,6 +461,12 @@ func TestDeleteUser(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
req = nbcontext.SetUserAuthInRequest(req, nbcontext.UserAuth{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.deleteUser(rr, req)
|
||||
|
||||
Reference in New Issue
Block a user