mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 15:16:38 +00:00
[management] restructure api files (#3013)
This commit is contained in:
187
management/server/http/handlers/users/pat_handler.go
Normal file
187
management/server/http/handlers/users/pat_handler.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
// patHandler is the nameserver group handler of the account
|
||||
type patHandler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
tokenHandler := newPATsHandler(accountManager, authCfg)
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
|
||||
return &patHandler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var patResponse []*api.PersonalAccessToken
|
||||
for _, pat := range pats {
|
||||
patResponse = append(patResponse, toPATResponse(pat))
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, patResponse)
|
||||
}
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
|
||||
}
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
|
||||
}
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
tokenID := vars["tokenId"]
|
||||
if len(tokenID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
|
||||
var lastUsed *time.Time
|
||||
if !pat.LastUsed.IsZero() {
|
||||
lastUsed = &pat.LastUsed
|
||||
}
|
||||
return &api.PersonalAccessToken{
|
||||
CreatedAt: pat.CreatedAt,
|
||||
CreatedBy: pat.CreatedBy,
|
||||
Name: pat.Name,
|
||||
ExpirationDate: pat.ExpirationDate,
|
||||
Id: pat.ID,
|
||||
LastUsed: lastUsed,
|
||||
}
|
||||
}
|
||||
|
||||
func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
|
||||
return &api.PersonalAccessTokenGenerated{
|
||||
PlainToken: pat.PlainToken,
|
||||
PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken),
|
||||
}
|
||||
}
|
||||
255
management/server/http/handlers/users/pat_handler_test.go
Normal file
255
management/server/http/handlers/users/pat_handler_test.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
existingAccountID = "existingAccountID"
|
||||
notFoundAccountID = "notFoundAccountID"
|
||||
existingUserID = "existingUserID"
|
||||
notFoundUserID = "notFoundUserID"
|
||||
existingTokenID = "existingTokenID"
|
||||
notFoundTokenID = "notFoundTokenID"
|
||||
testDomain = "hotmail.com"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
existingTokenID: {
|
||||
ID: existingTokenID,
|
||||
Name: "My first token",
|
||||
HashedToken: "someHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: existingUserID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
"token2": {
|
||||
ID: "token2",
|
||||
Name: "My second token",
|
||||
HashedToken: "someOtherHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: existingUserID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func initPATTestData() *patHandler {
|
||||
return &patHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return &server.PersonalAccessTokenGenerated{
|
||||
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
|
||||
PersonalAccessToken: server.PersonalAccessToken{},
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
if tokenID != existingTokenID {
|
||||
return status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
if tokenID != existingTokenID {
|
||||
return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandlers(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "Get All Tokens",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Existing Token",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Not Existing Token",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Delete Existing Token",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Not Existing Token",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "POST OK",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"expires_in\":7}")),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
p := initPATTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
if !tc.expectedBody {
|
||||
return
|
||||
}
|
||||
|
||||
switch tc.name {
|
||||
case "POST OK":
|
||||
got := &api.PersonalAccessTokenGenerated{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.NotEmpty(t, got.PlainToken)
|
||||
assert.Equal(t, server.PATLength, len(got.PlainToken))
|
||||
case "Get All Tokens":
|
||||
expectedTokens := []api.PersonalAccessToken{
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]),
|
||||
}
|
||||
|
||||
var got []api.PersonalAccessToken
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.True(t, cmp.Equal(got, expectedTokens))
|
||||
case "Get Existing Token":
|
||||
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
|
||||
got := &api.PersonalAccessToken{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.True(t, cmp.Equal(*got, expectedToken))
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
|
||||
return api.PersonalAccessToken{
|
||||
Id: serverToken.ID,
|
||||
Name: serverToken.Name,
|
||||
CreatedAt: serverToken.CreatedAt,
|
||||
LastUsed: &serverToken.LastUsed,
|
||||
CreatedBy: serverToken.CreatedBy,
|
||||
ExpirationDate: serverToken.ExpirationDate,
|
||||
}
|
||||
}
|
||||
304
management/server/http/handlers/users/users_handler.go
Normal file
304
management/server/http/handlers/users/users_handler.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
type handler struct {
|
||||
accountManager server.AccountManager
|
||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
|
||||
userHandler := newHandler(accountManager, authCfg)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, authCfg, router)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithAudience(authCfg.Audience),
|
||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// updateUser is a PUT requests to update User data
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PutApiUsersUserIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if req.AutoGroups == nil {
|
||||
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userRole := server.StrRoleToUserRole(req.Role)
|
||||
if userRole == server.UserRoleUnknown {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
Blocked: req.IsBlocked,
|
||||
Issued: existingUser.Issued,
|
||||
IntegrationReference: existingUser.IntegrationReference,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
|
||||
return
|
||||
}
|
||||
|
||||
email := ""
|
||||
if req.Email != nil {
|
||||
email = *req.Email
|
||||
}
|
||||
|
||||
name := ""
|
||||
if req.Name != nil {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
AutoGroups: req.AutoGroups,
|
||||
IsServiceUser: req.IsServiceUser,
|
||||
Issued: server.UserIssuedAPI,
|
||||
})
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
serviceUser := r.URL.Query().Get("service_user")
|
||||
|
||||
users := make([]*api.User, 0)
|
||||
for _, d := range data {
|
||||
if d.NonDeletable {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
continue
|
||||
}
|
||||
|
||||
includeServiceUser, err := strconv.ParseBool(serviceUser)
|
||||
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, claims.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, users)
|
||||
}
|
||||
|
||||
// inviteUser resend invitations to users who haven't activated their accounts,
|
||||
// prior to the expiration period.
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
claims := h.claimsExtractor.FromRequestContext(r)
|
||||
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
|
||||
autoGroups := user.AutoGroups
|
||||
if autoGroups == nil {
|
||||
autoGroups = []string{}
|
||||
}
|
||||
|
||||
var userStatus api.UserStatus
|
||||
switch user.Status {
|
||||
case "active":
|
||||
userStatus = api.UserStatusActive
|
||||
case "invited":
|
||||
userStatus = api.UserStatusInvited
|
||||
default:
|
||||
userStatus = api.UserStatusBlocked
|
||||
}
|
||||
|
||||
if user.IsBlocked {
|
||||
userStatus = api.UserStatusBlocked
|
||||
}
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
Permissions: &api.UserPermissions{
|
||||
DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView),
|
||||
},
|
||||
}
|
||||
}
|
||||
468
management/server/http/handlers/users/users_handler_test.go
Normal file
468
management/server/http/handlers/users/users_handler_test.go
Normal file
@@ -0,0 +1,468 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
serviceUserID = "serviceUserID"
|
||||
nonDeletableServiceUserID = "nonDeletableServiceUserID"
|
||||
regularUserID = "regularUserID"
|
||||
)
|
||||
|
||||
var usersTestAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
Issued: server.UserIssuedAPI,
|
||||
},
|
||||
regularUserID: {
|
||||
Id: regularUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: false,
|
||||
AutoGroups: []string{"group_1"},
|
||||
Issued: server.UserIssuedAPI,
|
||||
},
|
||||
serviceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "user",
|
||||
IsServiceUser: true,
|
||||
AutoGroups: []string{"group_1"},
|
||||
Issued: server.UserIssuedAPI,
|
||||
},
|
||||
nonDeletableServiceUserID: {
|
||||
Id: serviceUserID,
|
||||
Role: "admin",
|
||||
IsServiceUser: true,
|
||||
NonDeletable: true,
|
||||
Issued: server.UserIssuedIntegration,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func initUsersTestData() *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return usersTestAccount.Id, claims.UserId, nil
|
||||
},
|
||||
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
|
||||
return usersTestAccount.Users[id], nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
|
||||
users := make([]*server.UserInfo, 0)
|
||||
for _, v := range usersTestAccount.Users {
|
||||
users = append(users, &server.UserInfo{
|
||||
ID: v.Id,
|
||||
Role: string(v.Role),
|
||||
Name: "",
|
||||
Email: "",
|
||||
IsServiceUser: v.IsServiceUser,
|
||||
NonDeletable: v.NonDeletable,
|
||||
Issued: v.Issued,
|
||||
})
|
||||
}
|
||||
return users, nil
|
||||
},
|
||||
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
if !usersTestAccount.Users[targetUserID].IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "user with ID %s is not a service user and can not be deleted", targetUserID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
|
||||
if update.Id == notFoundUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
|
||||
}
|
||||
|
||||
if userID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
|
||||
}
|
||||
|
||||
info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return info, nil
|
||||
},
|
||||
InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if initiatorUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
|
||||
}
|
||||
|
||||
if targetUserID == notFoundUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUsers(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
expectedUserIDs []string
|
||||
}{
|
||||
{name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
|
||||
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
|
||||
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
respBody := []*server.UserInfo{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, len(respBody), len(tc.expectedUserIDs))
|
||||
for _, v := range respBody {
|
||||
assert.Contains(t, tc.expectedUserIDs, v.ID)
|
||||
assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id)
|
||||
assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role))
|
||||
assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser)
|
||||
assert.Equal(t, v.Issued, usersTestAccount.Users[v.ID].Issued)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatusCode int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
expectedUserID string
|
||||
expectedRole string
|
||||
expectedStatus string
|
||||
expectedBlocked bool
|
||||
expectedIsServiceUser bool
|
||||
expectedGroups []string
|
||||
}{
|
||||
{
|
||||
name: "Update_Block_User",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: true,
|
||||
expectedRole: "user",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_1"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
|
||||
},
|
||||
{
|
||||
name: "Update_Change_Role_To_Admin",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_1"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
{
|
||||
name: "Update_Groups",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_2", "group_3"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
{
|
||||
name: "Should_Fail_Because_AutoGroups_Is_Absent",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedUserID: regularUserID,
|
||||
expectedBlocked: false,
|
||||
expectedRole: "admin",
|
||||
expectedStatus: "blocked",
|
||||
expectedGroups: []string{"group_2", "group_3"},
|
||||
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatusCode {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
if tc.expectedStatusCode == 200 {
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
respBody := &api.User{}
|
||||
err = json.Unmarshal(content, &respBody)
|
||||
if err != nil {
|
||||
t.Fatalf("response content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedUserID, respBody.Id)
|
||||
assert.Equal(t, tc.expectedRole, respBody.Role)
|
||||
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
|
||||
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
|
||||
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
|
||||
|
||||
for _, expectedGroup := range tc.expectedGroups {
|
||||
exists := false
|
||||
for _, actualGroup := range respBody.AutoGroups {
|
||||
if expectedGroup == actualGroup {
|
||||
exists = true
|
||||
}
|
||||
}
|
||||
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateUser(t *testing.T) {
|
||||
name := "name"
|
||||
email := "email"
|
||||
serviceUserToAdd := api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: nil,
|
||||
IsServiceUser: true,
|
||||
Name: &name,
|
||||
Role: "admin",
|
||||
}
|
||||
serviceUserString, err := json.Marshal(serviceUserToAdd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
regularUserToAdd := api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: &email,
|
||||
IsServiceUser: true,
|
||||
Name: &name,
|
||||
Role: "admin",
|
||||
}
|
||||
regularUserString, err := json.Marshal(regularUserToAdd)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
expectedResult []*server.User
|
||||
}{
|
||||
{name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
|
||||
// right now creation is blocked in AC middleware, will be refactored in the future
|
||||
{name: "CreateRegularUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(regularUserString)},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.createUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestType string
|
||||
requestPath string
|
||||
requestVars map[string]string
|
||||
}{
|
||||
{
|
||||
name: "Invite User with Existing User",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + existingUserID + "/invite",
|
||||
expectedStatus: http.StatusOK,
|
||||
requestVars: map[string]string{"userId": existingUserID},
|
||||
},
|
||||
{
|
||||
name: "Invite User with missing user_id",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + notFoundUserID + "/invite",
|
||||
expectedStatus: http.StatusNotFound,
|
||||
requestVars: map[string]string{"userId": notFoundUserID},
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.inviteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUser(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestVars map[string]string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "Delete Regular User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + regularUserID,
|
||||
requestVars: map[string]string{"userId": regularUserID},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Delete Service User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + serviceUserID,
|
||||
requestVars: map[string]string{"userId": serviceUserID},
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Not Existing User",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + notFoundUserID,
|
||||
requestVars: map[string]string{"userId": notFoundUserID},
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
userHandler := initUsersTestData()
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
|
||||
req = mux.SetURLVars(req, tc.requestVars)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.deleteUser(rr, req)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
if status := rr.Code; status != tc.expectedStatus {
|
||||
t.Fatalf("handler returned wrong status code: got %v want %v",
|
||||
status, tc.expectedStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user