diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index c9a373411..2668198c4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -317,6 +317,17 @@ components: - created_by - created_at - last_used + PersonalAccessTokenGenerated: + type: object + properties: + plain_token: + description: Plain text representation of the generated token + type: string + personal_access_token: + $ref: '#/components/schemas/PersonalAccessToken' + required: + - plain_token + - personal_access_token PersonalAccessTokenRequest: type: object properties: @@ -945,9 +956,9 @@ paths: '200': description: The token in plain text content: - text/plain: + application/json: schema: - type: string + $ref: '#/components/schemas/PersonalAccessTokenGenerated' '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 4727e471e..24abaf829 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -400,6 +400,14 @@ type PersonalAccessToken struct { Name string `json:"name"` } +// PersonalAccessTokenGenerated defines model for PersonalAccessTokenGenerated. +type PersonalAccessTokenGenerated struct { + PersonalAccessToken PersonalAccessToken `json:"personal_access_token"` + + // PlainToken Plain text representation of the generated token + PlainToken string `json:"plain_token"` +} + // PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest. type PersonalAccessTokenRequest struct { // ExpiresIn Expiration in days diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 034caa4e9..ab356b8c9 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -81,7 +81,18 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat := account.Users[userID].PATs[tokenID] + user = account.Users[userID] + if user == nil { + util.WriteError(status.Errorf(status.NotFound, "user not found"), w) + return + } + + pat := user.PATs[tokenID] + if pat == nil { + util.WriteError(status.Errorf(status.NotFound, "PAT not found"), w) + return + } + util.WriteJSONObject(w, toPATResponse(pat)) } @@ -121,14 +132,14 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, plainToken, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) - err = h.accountManager.AddPATToUser(account.Id, userID, pat) + pat, err := server.CreateNewPAT(req.Name, req.ExpiresIn, user.Id) + err = h.accountManager.AddPATToUser(account.Id, userID, &pat.PersonalAccessToken) if err != nil { util.WriteError(err, w) return } - util.WriteJSONObject(w, plainToken) + util.WriteJSONObject(w, toPATGeneratedResponse(pat)) } func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { @@ -175,3 +186,10 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { LastUsed: pat.LastUsed, } } + +func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { + return &api.PersonalAccessTokenGenerated{ + PlainToken: pat.PlainToken, + PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), + } +} diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index 3d32e7c30..c0a7185cb 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -1,37 +1,231 @@ package http import ( + "bytes" + "encoding/json" + "fmt" + "io" "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" ) +const ( + existingAccountID = "existingAccountID" + notFoundAccountID = "notFoundAccountID" + existingUserID = "existingUserID" + notFoundUserID = "notFoundUserID" + existingTokenID = "existingTokenID" + notFoundTokenID = "notFoundTokenID" + domain = "hotmail.com" +) + +var testAccount = &server.Account{ + Id: existingAccountID, + Domain: domain, + Users: map[string]*server.User{ + existingUserID: { + Id: existingUserID, + PATs: map[string]*server.PersonalAccessToken{ + existingTokenID: { + ID: existingTokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + "token2": { + ID: "token2", + Name: "My second token", + HashedToken: "someOtherHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + func initPATTestData() *PATHandler { return &PATHandler{ accountManager: &mock_server.MockAccountManager{ - AddPATToUserFunc: func(accountID string, userID string, pat *server.PersonalAccessToken) error { - if nsGroupID == existingNSGroupID { - return baseExistingNSGroup.Copy(), nil + if accountID != existingAccountID { + return status.Errorf(status.NotFound, "account with ID %s not found", accountID) } - return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + if userID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", userID) + } + return nil }, GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + return testAccount, testAccount.Users[existingUserID], nil + }, + DeletePATFunc: func(accountID string, userID string, tokenID string) error { + if accountID != existingAccountID { + return status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if userID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", userID) + } + if tokenID != existingTokenID { + return status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ - UserId: "test_user", - Domain: "hotmail.com", + UserId: existingUserID, + Domain: domain, AccountId: testNSGroupAccountID, } }), ), } } + +func TestTokenHandlers(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "Get All Tokens", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens", + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Not Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "Delete Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + }, + { + name: "Delete Not Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "POST OK", + requestType: http.MethodPost, + requestPath: "/api/users/" + existingUserID + "/tokens", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprint("{\"name\":\"name\",\"expires_in\":7}"))), + expectedStatus: http.StatusOK, + expectedBody: true, + }, + } + + p := initPATTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + switch tc.name { + case "POST OK": + got := &api.PersonalAccessTokenGenerated{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.NotEmpty(t, got.PlainToken) + assert.Equal(t, server.PATLength, len(got.PlainToken)) + case "Get All Tokens": + expectedTokens := []api.PersonalAccessToken{ + toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), + toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]), + } + + var got []api.PersonalAccessToken + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.True(t, cmp.Equal(got, expectedTokens)) + case "Get Existing Token": + expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]) + got := &api.PersonalAccessToken{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.True(t, cmp.Equal(*got, expectedToken)) + } + + }) + } +} + +func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { + return api.PersonalAccessToken{ + Id: serverToken.ID, + Name: serverToken.Name, + CreatedAt: serverToken.CreatedAt, + LastUsed: serverToken.LastUsed, + CreatedBy: serverToken.CreatedBy, + ExpirationDate: serverToken.ExpirationDate, + } +} diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 817605dce..6eab840b8 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -34,23 +34,33 @@ type PersonalAccessToken struct { LastUsed time.Time } +// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it +type PersonalAccessTokenGenerated struct { + PlainToken string + PersonalAccessToken +} + // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version -func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) { +func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { hashedToken, plainToken, err := generateNewToken() if err != nil { - return nil, "", err + return nil, err } currentTime := time.Now().UTC() - return &PersonalAccessToken{ - ID: xid.New().String(), - Name: name, - HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), - CreatedBy: createdBy, - CreatedAt: currentTime, - LastUsed: currentTime, - }, plainToken, nil + return &PersonalAccessTokenGenerated{ + PersonalAccessToken: PersonalAccessToken{ + ID: xid.New().String(), + Name: name, + HashedToken: hashedToken, + ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + CreatedBy: createdBy, + CreatedAt: currentTime, + LastUsed: currentTime, + }, + PlainToken: plainToken, + }, nil + } func generateNewToken() (string, string, error) {