[management] restructure api files (#3013)

This commit is contained in:
Pascal Fischer
2024-12-10 15:59:25 +01:00
committed by GitHub
parent 97bb74f824
commit 6142828a9c
30 changed files with 461 additions and 426 deletions

View File

@@ -0,0 +1,187 @@
package users
import (
"encoding/json"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// patHandler is the nameserver group handler of the account
type patHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
tokenHandler := newPATsHandler(accountManager, authCfg)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
}
// newPATsHandler creates a new patHandler HTTP handler
func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler {
return &patHandler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var patResponse []*api.PersonalAccessToken
for _, pat := range pats {
patResponse = append(patResponse, toPATResponse(pat))
}
util.WriteJSONObject(r.Context(), w, patResponse)
}
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toPATResponse(pat))
}
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat))
}
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
tokenID := vars["tokenId"]
if len(tokenID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid token ID"), w)
return
}
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken {
var lastUsed *time.Time
if !pat.LastUsed.IsZero() {
lastUsed = &pat.LastUsed
}
return &api.PersonalAccessToken{
CreatedAt: pat.CreatedAt,
CreatedBy: pat.CreatedBy,
Name: pat.Name,
ExpirationDate: pat.ExpirationDate,
Id: pat.ID,
LastUsed: lastUsed,
}
}
func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated {
return &api.PersonalAccessTokenGenerated{
PlainToken: pat.PlainToken,
PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken),
}
}

View File

@@ -0,0 +1,255 @@
package users
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
)
const (
existingAccountID = "existingAccountID"
notFoundAccountID = "notFoundAccountID"
existingUserID = "existingUserID"
notFoundUserID = "notFoundUserID"
existingTokenID = "existingTokenID"
notFoundTokenID = "notFoundTokenID"
testDomain = "hotmail.com"
)
var testAccount = &server.Account{
Id: existingAccountID,
Domain: testDomain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
PATs: map[string]*server.PersonalAccessToken{
existingTokenID: {
ID: existingTokenID,
Name: "My first token",
HashedToken: "someHash",
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
CreatedBy: existingUserID,
CreatedAt: time.Now().UTC(),
LastUsed: time.Now().UTC(),
},
"token2": {
ID: "token2",
Name: "My second token",
HashedToken: "someOtherHash",
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
CreatedBy: existingUserID,
CreatedAt: time.Now().UTC(),
LastUsed: time.Now().UTC(),
},
},
},
},
}
func initPATTestData() *patHandler {
return &patHandler{
accountManager: &mock_server.MockAccountManager{
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
return &server.PersonalAccessTokenGenerated{
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
PersonalAccessToken: server.PersonalAccessToken{},
}, nil
},
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return claims.AccountId, claims.UserId, nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
if tokenID != existingTokenID {
return status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
}
return nil
},
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
if tokenID != existingTokenID {
return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
}
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
},
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
if accountID != existingAccountID {
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
}
if targetUserID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
}
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
}
}),
),
}
}
func TestTokenHandlers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "Get All Tokens",
requestType: http.MethodGet,
requestPath: "/api/users/" + existingUserID + "/tokens",
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Get Existing Token",
requestType: http.MethodGet,
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
expectedStatus: http.StatusOK,
expectedBody: true,
},
{
name: "Get Not Existing Token",
requestType: http.MethodGet,
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
expectedStatus: http.StatusNotFound,
},
{
name: "Delete Existing Token",
requestType: http.MethodDelete,
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
expectedStatus: http.StatusOK,
},
{
name: "Delete Not Existing Token",
requestType: http.MethodDelete,
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
expectedStatus: http.StatusNotFound,
},
{
name: "POST OK",
requestType: http.MethodPost,
requestPath: "/api/users/" + existingUserID + "/tokens",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"expires_in\":7}")),
expectedStatus: http.StatusOK,
expectedBody: true,
},
}
p := initPATTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
switch tc.name {
case "POST OK":
got := &api.PersonalAccessTokenGenerated{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.NotEmpty(t, got.PlainToken)
assert.Equal(t, server.PATLength, len(got.PlainToken))
case "Get All Tokens":
expectedTokens := []api.PersonalAccessToken{
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]),
}
var got []api.PersonalAccessToken
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.True(t, cmp.Equal(got, expectedTokens))
case "Get Existing Token":
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
got := &api.PersonalAccessToken{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.True(t, cmp.Equal(*got, expectedToken))
}
})
}
}
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
return api.PersonalAccessToken{
Id: serverToken.ID,
Name: serverToken.Name,
CreatedAt: serverToken.CreatedAt,
LastUsed: &serverToken.LastUsed,
CreatedBy: serverToken.CreatedBy,
ExpirationDate: serverToken.ExpirationDate,
}
}

View File

@@ -0,0 +1,304 @@
package users
import (
"encoding/json"
"net/http"
"strconv"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
// handler is a handler that returns users of the account
type handler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) {
userHandler := newHandler(accountManager, authCfg)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
addUsersTokensEndpoint(accountManager, authCfg, router)
}
// newHandler creates a new UsersHandler HTTP handler
func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler {
return &handler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// updateUser is a PUT requests to update User data
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
req := &api.PutApiUsersUserIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.AutoGroups == nil {
util.WriteErrorResponse("auto_groups field can't be absent", http.StatusBadRequest, w)
return
}
userRole := server.StrRoleToUserRole(req.Role)
if userRole == server.UserRoleUnknown {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w)
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
Issued: existingUser.Issued,
IntegrationReference: existingUser.IntegrationReference,
})
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
}
// deleteUser is a DELETE request to delete a user
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w)
return
}
email := ""
if req.Email != nil {
email = *req.Email
}
name := ""
if req.Name != nil {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
Email: email,
Name: name,
Role: req.Role,
AutoGroups: req.AutoGroups,
IsServiceUser: req.IsServiceUser,
Issued: server.UserIssuedAPI,
})
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId))
}
// getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
serviceUser := r.URL.Query().Get("service_user")
users := make([]*api.User, 0)
for _, d := range data {
if d.NonDeletable {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, claims.UserId))
continue
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, claims.UserId))
}
}
util.WriteJSONObject(r.Context(), w, users)
}
// inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period.
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
claims := h.claimsExtractor.FromRequestContext(r)
accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func toUserResponse(user *server.UserInfo, currenUserID string) *api.User {
autoGroups := user.AutoGroups
if autoGroups == nil {
autoGroups = []string{}
}
var userStatus api.UserStatus
switch user.Status {
case "active":
userStatus = api.UserStatusActive
case "invited":
userStatus = api.UserStatusInvited
default:
userStatus = api.UserStatusBlocked
}
if user.IsBlocked {
userStatus = api.UserStatusBlocked
}
isCurrent := user.ID == currenUserID
return &api.User{
Id: user.ID,
Name: user.Name,
Email: user.Email,
Role: user.Role,
AutoGroups: autoGroups,
Status: userStatus,
IsCurrent: &isCurrent,
IsServiceUser: &user.IsServiceUser,
IsBlocked: user.IsBlocked,
LastLogin: &user.LastLogin,
Issued: &user.Issued,
Permissions: &api.UserPermissions{
DashboardView: (*api.UserPermissionsDashboardView)(&user.Permissions.DashboardView),
},
}
}

View File

@@ -0,0 +1,468 @@
package users
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/status"
)
const (
serviceUserID = "serviceUserID"
nonDeletableServiceUserID = "nonDeletableServiceUserID"
regularUserID = "regularUserID"
)
var usersTestAccount = &server.Account{
Id: existingAccountID,
Domain: testDomain,
Users: map[string]*server.User{
existingUserID: {
Id: existingUserID,
Role: "admin",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
Issued: server.UserIssuedAPI,
},
regularUserID: {
Id: regularUserID,
Role: "user",
IsServiceUser: false,
AutoGroups: []string{"group_1"},
Issued: server.UserIssuedAPI,
},
serviceUserID: {
Id: serviceUserID,
Role: "user",
IsServiceUser: true,
AutoGroups: []string{"group_1"},
Issued: server.UserIssuedAPI,
},
nonDeletableServiceUserID: {
Id: serviceUserID,
Role: "admin",
IsServiceUser: true,
NonDeletable: true,
Issued: server.UserIssuedIntegration,
},
},
}
func initUsersTestData() *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
return usersTestAccount.Id, claims.UserId, nil
},
GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range usersTestAccount.Users {
users = append(users, &server.UserInfo{
ID: v.Id,
Role: string(v.Role),
Name: "",
Email: "",
IsServiceUser: v.IsServiceUser,
NonDeletable: v.NonDeletable,
Issued: v.Issued,
})
}
return users, nil
},
CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) {
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
return key, nil
},
DeleteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
if !usersTestAccount.Users[targetUserID].IsServiceUser {
return status.Errorf(status.PermissionDenied, "user with ID %s is not a service user and can not be deleted", targetUserID)
}
return nil
},
SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) {
if update.Id == notFoundUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id)
}
if userID != existingUserID {
return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID)
}
info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false})
if err != nil {
return nil, err
}
return info, nil
},
InviteUserFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) error {
if initiatorUserID != existingUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", initiatorUserID)
}
if targetUserID == notFoundUserID {
return status.Errorf(status.NotFound, "user with ID %s does not exists", targetUserID)
}
return nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
AccountId: existingAccountID,
}
}),
),
}
}
func TestGetUsers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
expectedUserIDs []string
}{
{name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}},
{name: "GetOnlyServiceUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}},
{name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
userHandler.getAllUsers(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
respBody := []*server.UserInfo{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, len(respBody), len(tc.expectedUserIDs))
for _, v := range respBody {
assert.Contains(t, tc.expectedUserIDs, v.ID)
assert.Equal(t, v.ID, usersTestAccount.Users[v.ID].Id)
assert.Equal(t, v.Role, string(usersTestAccount.Users[v.ID].Role))
assert.Equal(t, v.IsServiceUser, usersTestAccount.Users[v.ID].IsServiceUser)
assert.Equal(t, v.Issued, usersTestAccount.Users[v.ID].Issued)
}
})
}
}
func TestUpdateUser(t *testing.T) {
tt := []struct {
name string
expectedStatusCode int
requestType string
requestPath string
requestBody io.Reader
expectedUserID string
expectedRole string
expectedStatus string
expectedBlocked bool
expectedIsServiceUser bool
expectedGroups []string
}{
{
name: "Update_Block_User",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: true,
expectedRole: "user",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"user\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": true}"),
},
{
name: "Update_Change_Role_To_Admin",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_1"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_1\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Update_Groups",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusOK,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"auto_groups\":[\"group_3\", \"group_2\"],\"is_service_user\":false, \"is_blocked\": false}"),
},
{
name: "Should_Fail_Because_AutoGroups_Is_Absent",
requestType: http.MethodPut,
requestPath: "/api/users/" + regularUserID,
expectedStatusCode: http.StatusBadRequest,
expectedUserID: regularUserID,
expectedBlocked: false,
expectedRole: "admin",
expectedStatus: "blocked",
expectedGroups: []string{"group_2", "group_3"},
requestBody: bytes.NewBufferString("{\"role\":\"admin\",\"is_service_user\":false, \"is_blocked\": false}"),
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatusCode {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
if tc.expectedStatusCode == 200 {
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
respBody := &api.User{}
err = json.Unmarshal(content, &respBody)
if err != nil {
t.Fatalf("response content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedUserID, respBody.Id)
assert.Equal(t, tc.expectedRole, respBody.Role)
assert.Equal(t, tc.expectedIsServiceUser, *respBody.IsServiceUser)
assert.Equal(t, tc.expectedBlocked, respBody.IsBlocked)
assert.Len(t, respBody.AutoGroups, len(tc.expectedGroups))
for _, expectedGroup := range tc.expectedGroups {
exists := false
for _, actualGroup := range respBody.AutoGroups {
if expectedGroup == actualGroup {
exists = true
}
}
assert.True(t, exists, fmt.Sprintf("group %s not found in the response", expectedGroup))
}
}
})
}
}
func TestCreateUser(t *testing.T) {
name := "name"
email := "email"
serviceUserToAdd := api.UserCreateRequest{
AutoGroups: []string{},
Email: nil,
IsServiceUser: true,
Name: &name,
Role: "admin",
}
serviceUserString, err := json.Marshal(serviceUserToAdd)
if err != nil {
t.Fatal(err)
}
regularUserToAdd := api.UserCreateRequest{
AutoGroups: []string{},
Email: &email,
IsServiceUser: true,
Name: &name,
Role: "admin",
}
regularUserString, err := json.Marshal(regularUserToAdd)
if err != nil {
t.Fatal(err)
}
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestBody io.Reader
expectedResult []*server.User
}{
{name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)},
// right now creation is blocked in AC middleware, will be refactored in the future
{name: "CreateRegularUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(regularUserString)},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
rr := httptest.NewRecorder()
userHandler.createUser(rr, req)
res := rr.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
}
})
}
}
func TestInviteUser(t *testing.T) {
tt := []struct {
name string
expectedStatus int
requestType string
requestPath string
requestVars map[string]string
}{
{
name: "Invite User with Existing User",
requestType: http.MethodPost,
requestPath: "/api/users/" + existingUserID + "/invite",
expectedStatus: http.StatusOK,
requestVars: map[string]string{"userId": existingUserID},
},
{
name: "Invite User with missing user_id",
requestType: http.MethodPost,
requestPath: "/api/users/" + notFoundUserID + "/invite",
expectedStatus: http.StatusNotFound,
requestVars: map[string]string{"userId": notFoundUserID},
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
}
})
}
}
func TestDeleteUser(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestVars map[string]string
requestBody io.Reader
}{
{
name: "Delete Regular User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + regularUserID,
requestVars: map[string]string{"userId": regularUserID},
expectedStatus: http.StatusForbidden,
},
{
name: "Delete Service User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + serviceUserID,
requestVars: map[string]string{"userId": serviceUserID},
expectedStatus: http.StatusOK,
},
{
name: "Delete Not Existing User",
requestType: http.MethodDelete,
requestPath: "/api/users/" + notFoundUserID,
requestVars: map[string]string{"userId": notFoundUserID},
expectedStatus: http.StatusNotFound,
},
}
userHandler := initUsersTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(tc.requestType, tc.requestPath, nil)
req = mux.SetURLVars(req, tc.requestVars)
rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req)
res := rr.Result()
defer res.Body.Close()
if status := rr.Code; status != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
}
})
}
}