mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management] restructure api files (#3013)
This commit is contained in:
255
management/server/http/handlers/users/pat_handler_test.go
Normal file
255
management/server/http/handlers/users/pat_handler_test.go
Normal file
@@ -0,0 +1,255 @@
|
||||
package users
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
existingAccountID = "existingAccountID"
|
||||
notFoundAccountID = "notFoundAccountID"
|
||||
existingUserID = "existingUserID"
|
||||
notFoundUserID = "notFoundUserID"
|
||||
existingTokenID = "existingTokenID"
|
||||
notFoundTokenID = "notFoundTokenID"
|
||||
testDomain = "hotmail.com"
|
||||
)
|
||||
|
||||
var testAccount = &server.Account{
|
||||
Id: existingAccountID,
|
||||
Domain: testDomain,
|
||||
Users: map[string]*server.User{
|
||||
existingUserID: {
|
||||
Id: existingUserID,
|
||||
PATs: map[string]*server.PersonalAccessToken{
|
||||
existingTokenID: {
|
||||
ID: existingTokenID,
|
||||
Name: "My first token",
|
||||
HashedToken: "someHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: existingUserID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
"token2": {
|
||||
ID: "token2",
|
||||
Name: "My second token",
|
||||
HashedToken: "someOtherHash",
|
||||
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
||||
CreatedBy: existingUserID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: time.Now().UTC(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func initPATTestData() *patHandler {
|
||||
return &patHandler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return &server.PersonalAccessTokenGenerated{
|
||||
PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe",
|
||||
PersonalAccessToken: server.PersonalAccessToken{},
|
||||
}, nil
|
||||
},
|
||||
|
||||
GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
|
||||
return claims.AccountId, claims.UserId, nil
|
||||
},
|
||||
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
if accountID != existingAccountID {
|
||||
return status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
if tokenID != existingTokenID {
|
||||
return status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
if tokenID != existingTokenID {
|
||||
return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID)
|
||||
}
|
||||
return testAccount.Users[existingUserID].PATs[existingTokenID], nil
|
||||
},
|
||||
GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) {
|
||||
if accountID != existingAccountID {
|
||||
return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID)
|
||||
}
|
||||
if targetUserID != existingUserID {
|
||||
return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID)
|
||||
}
|
||||
return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
return jwtclaims.AuthorizationClaims{
|
||||
UserId: existingUserID,
|
||||
Domain: testDomain,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenHandlers(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestType string
|
||||
requestPath string
|
||||
requestBody io.Reader
|
||||
}{
|
||||
{
|
||||
name: "Get All Tokens",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens",
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Existing Token",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
{
|
||||
name: "Get Not Existing Token",
|
||||
requestType: http.MethodGet,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "Delete Existing Token",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID,
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "Delete Not Existing Token",
|
||||
requestType: http.MethodDelete,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID,
|
||||
expectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "POST OK",
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/users/" + existingUserID + "/tokens",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte("{\"name\":\"name\",\"expires_in\":7}")),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
p := initPATTestData()
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
|
||||
content, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("I don't know what I expected; %v", err)
|
||||
}
|
||||
|
||||
if status := recorder.Code; status != tc.expectedStatus {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
|
||||
status, tc.expectedStatus, string(content))
|
||||
return
|
||||
}
|
||||
|
||||
if !tc.expectedBody {
|
||||
return
|
||||
}
|
||||
|
||||
switch tc.name {
|
||||
case "POST OK":
|
||||
got := &api.PersonalAccessTokenGenerated{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.NotEmpty(t, got.PlainToken)
|
||||
assert.Equal(t, server.PATLength, len(got.PlainToken))
|
||||
case "Get All Tokens":
|
||||
expectedTokens := []api.PersonalAccessToken{
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]),
|
||||
toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]),
|
||||
}
|
||||
|
||||
var got []api.PersonalAccessToken
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
assert.True(t, cmp.Equal(got, expectedTokens))
|
||||
case "Get Existing Token":
|
||||
expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID])
|
||||
got := &api.PersonalAccessToken{}
|
||||
if err = json.Unmarshal(content, &got); err != nil {
|
||||
t.Fatalf("Sent content is not in correct json format; %v", err)
|
||||
}
|
||||
|
||||
assert.True(t, cmp.Equal(*got, expectedToken))
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken {
|
||||
return api.PersonalAccessToken{
|
||||
Id: serverToken.ID,
|
||||
Name: serverToken.Name,
|
||||
CreatedAt: serverToken.CreatedAt,
|
||||
LastUsed: &serverToken.LastUsed,
|
||||
CreatedBy: serverToken.CreatedBy,
|
||||
ExpirationDate: serverToken.ExpirationDate,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user