diff --git a/go.mod b/go.mod index d6518467f..c74fee409 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/getlantern/systray v1.2.1 github.com/gliderlabs/ssh v0.3.4 github.com/godbus/dbus/v5 v5.1.0 + github.com/google/go-cmp v0.5.9 github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 @@ -89,7 +90,6 @@ require ( github.com/go-stack/stack v1.8.0 // indirect github.com/gobwas/glob v0.2.3 // indirect github.com/goki/freetype v0.0.0-20181231101311-fa8a33aabaff // indirect - github.com/google/go-cmp v0.5.9 // indirect github.com/google/gopacket v1.1.19 // indirect github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/management/server/account.go b/management/server/account.go index 01cae2e64..98326e7ae 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -3,6 +3,7 @@ package server import ( "context" "crypto/sha256" + b64 "encoding/base64" "fmt" "hash/crc32" "math/rand" @@ -54,7 +55,7 @@ type AccountManager interface { GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) - GetAccountFromPAT(pat string) (*Account, *User, error) + GetAccountFromPAT(pat string) (*Account, *User, *PersonalAccessToken, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) @@ -66,8 +67,10 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) - AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error - DeletePAT(accountID string, userID string, tokenID string) error + CreatePAT(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error + GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) + GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) @@ -1120,44 +1123,51 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, error) { +func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { - return nil, nil, fmt.Errorf("token has wrong length") + return nil, nil, nil, fmt.Errorf("token has wrong length") } prefix := token[:len(PATPrefix)] if prefix != PATPrefix { - return nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, nil, fmt.Errorf("token has wrong prefix") } secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { - return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { - return nil, nil, fmt.Errorf("token checksum does not match") + return nil, nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) - tokenID, err := am.Store.GetTokenIDByHashedToken(string(hashedToken[:])) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + tokenID, err := am.Store.GetTokenIDByHashedToken(encodedHashedToken) if err != nil { - return nil, nil, err + return nil, nil, nil, err } user, err := am.Store.GetUserByTokenID(tokenID) if err != nil { - return nil, nil, err + return nil, nil, nil, err } account, err := am.Store.GetAccountByUser(user.Id) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return account, user, nil + + pat := user.PATs[tokenID] + if pat == nil { + return nil, nil, nil, fmt.Errorf("personal access token not found") + } + + return account, user, pat, nil } // GetAccountFromToken returns an account associated with this token diff --git a/management/server/account_test.go b/management/server/account_test.go index af894817b..25d501c8b 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "fmt" "net" "reflect" @@ -465,12 +466,13 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) account.Users["someUser"] = &User{ Id: "someUser", PATs: map[string]*PersonalAccessToken{ - "pat1": { + "tokenId": { ID: "tokenId", - HashedToken: string(hashedToken[:]), + HashedToken: encodedHashedToken, }, }, } @@ -483,13 +485,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { Store: store, } - account, user, err := am.GetAccountFromPAT(token) + account, user, pat, err := am.GetAccountFromPAT(token) if err != nil { t.Fatalf("Error when getting Account from PAT: %s", err) } assert.Equal(t, "account_id", account.Id) assert.Equal(t, "someUser", user.Id) + assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -1245,7 +1248,7 @@ func TestAccount_Copy(t *testing.T) { PATs: map[string]*PersonalAccessToken{ "pat1": { ID: "pat1", - Description: "First PAT", + Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", ExpirationDate: time.Now().AddDate(0, 0, 7), CreatedBy: "user1", diff --git a/management/server/file_store.go b/management/server/file_store.go index 4f8092cfb..9b4a9f47f 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -112,7 +112,7 @@ func restore(file string) (*FileStore, error) { store.UserID2AccountID[user.Id] = accountID for _, pat := range user.PATs { store.TokenID2UserID[pat.ID] = user.Id - store.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + store.HashedPAT2TokenID[pat.HashedToken] = pat.ID } } @@ -268,7 +268,7 @@ func (s *FileStore) SaveAccount(account *Account) error { s.UserID2AccountID[user.Id] = accountCopy.Id for _, pat := range user.PATs { s.TokenID2UserID[pat.ID] = user.Id - s.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + s.HashedPAT2TokenID[pat.HashedToken] = pat.ID } } diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index b3d954a4d..eaeb5693c 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -6,6 +6,8 @@ info: tags: - name: Users description: Interact with and view information about users. + - name: Tokens + description: Interact with and view information about tokens. - name: Peers description: Interact with and view information about peers. - name: Setup Keys @@ -284,6 +286,59 @@ components: - revoked - auto_groups - usage_limit + PersonalAccessToken: + type: object + properties: + id: + description: ID of a token + type: string + name: + description: Name of the token + type: string + expiration_date: + description: Date the token expires + type: string + format: date-time + created_by: + description: User ID of the user who created the token + type: string + created_at: + description: Date the token was created + type: string + format: date-time + last_used: + description: Date the token was last used + type: string + format: date-time + required: + - id + - name + - expiration_date + - created_by + - created_at + PersonalAccessTokenGenerated: + type: object + properties: + plain_token: + description: Plain text representation of the generated token + type: string + personal_access_token: + $ref: '#/components/schemas/PersonalAccessToken' + required: + - plain_token + - personal_access_token + PersonalAccessTokenRequest: + type: object + properties: + name: + description: Name of the token + type: string + expires_in: + description: Expiration in days + type: integer + required: + - name + - expires_in GroupMinimum: type: object properties: @@ -848,6 +903,133 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/users/{userId}/tokens: + get: + summary: Returns a list of all tokens for a user + tags: [ Tokens ] + security: + - BearerAuth: [] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + responses: + '200': + description: A JSON Array of PersonalAccessTokens + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/PersonalAccessToken' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a new token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + requestBody: + description: PersonalAccessToken create parameters + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessTokenRequest' + responses: + '200': + description: The token in plain text + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessTokenGenerated' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/users/{userId}/tokens/{tokenId}: + get: + summary: Returns a specific token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + - in: path + name: tokenId + required: true + schema: + type: string + description: The Token ID + responses: + '200': + description: A PersonalAccessTokens Object + content: + application/json: + schema: + $ref: '#/components/schemas/PersonalAccessToken' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a token + tags: [ Tokens ] + security: + - BearerAuth: [ ] + parameters: + - in: path + name: userId + required: true + schema: + type: string + description: The User ID + - in: path + name: tokenId + required: true + schema: + type: string + description: The Token ID + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/peers: get: summary: Returns a list of all peers diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 372ecd1a7..930a9df54 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -379,6 +379,44 @@ type PeerMinimum struct { Name string `json:"name"` } +// PersonalAccessToken defines model for PersonalAccessToken. +type PersonalAccessToken struct { + // CreatedAt Date the token was created + CreatedAt time.Time `json:"created_at"` + + // CreatedBy User ID of the user who created the token + CreatedBy string `json:"created_by"` + + // ExpirationDate Date the token expires + ExpirationDate time.Time `json:"expiration_date"` + + // Id ID of a token + Id string `json:"id"` + + // LastUsed Date the token was last used + LastUsed *time.Time `json:"last_used,omitempty"` + + // Name Name of the token + Name string `json:"name"` +} + +// PersonalAccessTokenGenerated defines model for PersonalAccessTokenGenerated. +type PersonalAccessTokenGenerated struct { + PersonalAccessToken PersonalAccessToken `json:"personal_access_token"` + + // PlainToken Plain text representation of the generated token + PlainToken string `json:"plain_token"` +} + +// PersonalAccessTokenRequest defines model for PersonalAccessTokenRequest. +type PersonalAccessTokenRequest struct { + // ExpiresIn Expiration in days + ExpiresIn int `json:"expires_in"` + + // Name Name of the token + Name string `json:"name"` +} + // Policy defines model for Policy. type Policy struct { // Description Policy friendly description @@ -808,3 +846,6 @@ type PostApiUsersJSONRequestBody = UserCreateRequest // PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType. type PutApiUsersIdJSONRequestBody = UserRequest + +// PostApiUsersUserIdTokensJSONRequestBody defines body for PostApiUsersUserIdTokens for application/json ContentType. +type PostApiUsersUserIdTokensJSONRequestBody = PersonalAccessTokenRequest diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index 9712d2e75..2464f47ef 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -300,7 +300,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetGroup returns a group diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 90f62e700..79028e6d2 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -25,6 +25,9 @@ type apiHandler struct { AuthCfg AuthCfg } +type emptyObject struct { +} + // APIHandler creates the Management service HTTP API handler registering all the available endpoints. func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics, authCfg AuthCfg) (http.Handler, error) { jwtMiddleware, err := middleware.NewJwtMiddleware( @@ -57,6 +60,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics api.addAccountsEndpoint() api.addPeersEndpoint() api.addUsersEndpoint() + api.addUsersTokensEndpoint() api.addSetupKeysEndpoint() api.addRulesEndpoint() api.addPoliciesEndpoint() @@ -110,6 +114,14 @@ func (apiHandler *apiHandler) addUsersEndpoint() { apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") } +func (apiHandler *apiHandler) addUsersTokensEndpoint() { + tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) + apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") + apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") +} + func (apiHandler *apiHandler) addSetupKeysEndpoint() { keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index e7be617e2..5ad52a426 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -243,7 +243,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetNameserverGroup handles a nameserver group Get request identified by ID diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go new file mode 100644 index 000000000..d2398a7e1 --- /dev/null +++ b/management/server/http/pat_handler.go @@ -0,0 +1,178 @@ +package http + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/status" +) + +// PATHandler is the nameserver group handler of the account +type PATHandler struct { + accountManager server.AccountManager + claimsExtractor *jwtclaims.ClaimsExtractor +} + +// NewPATsHandler creates a new PATHandler HTTP handler +func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { + return &PATHandler{ + accountManager: accountManager, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + userID := vars["userId"] + if len(userID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + pats, err := h.accountManager.GetAllPATs(account.Id, user.Id, userID) + if err != nil { + util.WriteError(err, w) + return + } + + var patResponse []*api.PersonalAccessToken + for _, pat := range pats { + patResponse = append(patResponse, toPATResponse(pat)) + } + + util.WriteJSONObject(w, patResponse) +} + +// GetToken is HTTP GET handler that returns a personal access token for the given user +func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + tokenID := vars["tokenId"] + if len(tokenID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + return + } + + pat, err := h.accountManager.GetPAT(account.Id, user.Id, targetUserID, tokenID) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPATResponse(pat)) +} + +// CreateToken is HTTP POST handler that creates a personal access token for the given user +func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + var req api.PostApiUsersUserIdTokensJSONRequestBody + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + pat, err := h.accountManager.CreatePAT(account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, toPATGeneratedResponse(pat)) +} + +// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + vars := mux.Vars(r) + targetUserID := vars["userId"] + if len(targetUserID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid user ID"), w) + return + } + + tokenID := vars["tokenId"] + if len(tokenID) == 0 { + util.WriteError(status.Errorf(status.InvalidArgument, "invalid token ID"), w) + return + } + + err = h.accountManager.DeletePAT(account.Id, user.Id, targetUserID, tokenID) + if err != nil { + util.WriteError(err, w) + return + } + + util.WriteJSONObject(w, emptyObject{}) +} + +func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { + var lastUsed *time.Time + if !pat.LastUsed.IsZero() { + lastUsed = &pat.LastUsed + } + return &api.PersonalAccessToken{ + CreatedAt: pat.CreatedAt, + CreatedBy: pat.CreatedBy, + Name: pat.Name, + ExpirationDate: pat.ExpirationDate, + Id: pat.ID, + LastUsed: lastUsed, + } +} + +func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { + return &api.PersonalAccessTokenGenerated{ + PlainToken: pat.PlainToken, + PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), + } +} diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go new file mode 100644 index 000000000..de79f1006 --- /dev/null +++ b/management/server/http/pat_handler_test.go @@ -0,0 +1,254 @@ +package http + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/status" +) + +const ( + existingAccountID = "existingAccountID" + notFoundAccountID = "notFoundAccountID" + existingUserID = "existingUserID" + notFoundUserID = "notFoundUserID" + existingTokenID = "existingTokenID" + notFoundTokenID = "notFoundTokenID" + domain = "hotmail.com" +) + +var testAccount = &server.Account{ + Id: existingAccountID, + Domain: domain, + Users: map[string]*server.User{ + existingUserID: { + Id: existingUserID, + PATs: map[string]*server.PersonalAccessToken{ + existingTokenID: { + ID: existingTokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + "token2": { + ID: "token2", + Name: "My second token", + HashedToken: "someOtherHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: existingUserID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + +func initPATTestData() *PATHandler { + return &PATHandler{ + accountManager: &mock_server.MockAccountManager{ + CreatePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + return &server.PersonalAccessTokenGenerated{ + PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", + PersonalAccessToken: server.PersonalAccessToken{}, + }, nil + }, + + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + return testAccount, testAccount.Users[existingUserID], nil + }, + DeletePATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) error { + if accountID != existingAccountID { + return status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + if tokenID != existingTokenID { + return status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return nil + }, + GetPATFunc: func(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + if tokenID != existingTokenID { + return nil, status.Errorf(status.NotFound, "token with ID %s not found", tokenID) + } + return testAccount.Users[existingUserID].PATs[existingTokenID], nil + }, + GetAllPATsFunc: func(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + if accountID != existingAccountID { + return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) + } + if targetUserID != existingUserID { + return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) + } + return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + }, + }, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: existingUserID, + Domain: domain, + AccountId: testNSGroupAccountID, + } + }), + ), + } +} + +func TestTokenHandlers(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "Get All Tokens", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens", + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + expectedBody: true, + }, + { + name: "Get Not Existing Token", + requestType: http.MethodGet, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "Delete Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + existingTokenID, + expectedStatus: http.StatusOK, + }, + { + name: "Delete Not Existing Token", + requestType: http.MethodDelete, + requestPath: "/api/users/" + existingUserID + "/tokens/" + notFoundTokenID, + expectedStatus: http.StatusNotFound, + }, + { + name: "POST OK", + requestType: http.MethodPost, + requestPath: "/api/users/" + existingUserID + "/tokens", + requestBody: bytes.NewBuffer( + []byte("{\"name\":\"name\",\"expires_in\":7}")), + expectedStatus: http.StatusOK, + expectedBody: true, + }, + } + + p := initPATTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + switch tc.name { + case "POST OK": + got := &api.PersonalAccessTokenGenerated{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.NotEmpty(t, got.PlainToken) + assert.Equal(t, server.PATLength, len(got.PlainToken)) + case "Get All Tokens": + expectedTokens := []api.PersonalAccessToken{ + toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), + toTokenResponse(*testAccount.Users[existingUserID].PATs["token2"]), + } + + var got []api.PersonalAccessToken + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.True(t, cmp.Equal(got, expectedTokens)) + case "Get Existing Token": + expectedToken := toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]) + got := &api.PersonalAccessToken{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + assert.True(t, cmp.Equal(*got, expectedToken)) + } + + }) + } +} + +func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { + return api.PersonalAccessToken{ + Id: serverToken.ID, + Name: serverToken.Name, + CreatedAt: serverToken.CreatedAt, + LastUsed: &serverToken.LastUsed, + CreatedBy: serverToken.CreatedBy, + ExpirationDate: serverToken.ExpirationDate, + } +} diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 76c4f7502..7379277af 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -66,7 +66,7 @@ func (h *PeersHandler) deletePeer(accountID, userID string, peerID string, w htt util.WriteError(err, w) return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations diff --git a/management/server/http/policies.go b/management/server/http/policies.go index 275992d14..a0fe3b1e2 100644 --- a/management/server/http/policies.go +++ b/management/server/http/policies.go @@ -225,7 +225,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetPolicy handles a group Get request identified by ID diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index b29e5c261..aaaaaa854 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -321,7 +321,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetRoute handles a route Get request identified by ID diff --git a/management/server/http/rules_handler.go b/management/server/http/rules_handler.go index 8925c3763..f8bb5f0cb 100644 --- a/management/server/http/rules_handler.go +++ b/management/server/http/rules_handler.go @@ -222,7 +222,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(w, "") + util.WriteJSONObject(w, emptyObject{}) } // GetRule handles a group Get request identified by ID diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2ae71d1a9..53cd2d672 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,7 +47,7 @@ type MockAccountManager struct { DeletePolicyFunc func(accountID, policyID, userID string) error ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, error) + GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) @@ -60,8 +60,10 @@ type MockAccountManager struct { SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) - AddPATToUserFunc func(accountID string, userID string, pat *server.PersonalAccessToken) error - DeletePATFunc func(accountID string, userID string, tokenID string) error + CreatePATFunc func(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + DeletePATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) error + GetPATFunc func(accountID string, executingUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) + GetAllPATsFunc func(accountID string, executingUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error @@ -179,29 +181,45 @@ func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*ser } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, error) { +func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { return am.GetAccountFromPATFunc(pat) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") + return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") } -// AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface -func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { - if am.AddPATToUserFunc != nil { - return am.AddPATToUserFunc(accountID, userID, pat) +// CreatePAT mock implementation of GetPAT from server.AccountManager interface +func (am *MockAccountManager) CreatePAT(accountID string, executingUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + if am.CreatePATFunc != nil { + return am.CreatePATFunc(accountID, executingUserID, targetUserID, name, expiresIn) } - return status.Errorf(codes.Unimplemented, "method AddPATToUser is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method CreatePAT is not implemented") } // DeletePAT mock implementation of DeletePAT from server.AccountManager interface -func (am *MockAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { +func (am *MockAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { if am.DeletePATFunc != nil { - return am.DeletePATFunc(accountID, userID, tokenID) + return am.DeletePATFunc(accountID, executingUserID, targetUserID, tokenID) } return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") } +// GetPAT mock implementation of GetPAT from server.AccountManager interface +func (am *MockAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + if am.GetPATFunc != nil { + return am.GetPATFunc(accountID, executingUserID, targetUserID, tokenID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPAT is not implemented") +} + +// GetAllPATs mock implementation of GetAllPATs from server.AccountManager interface +func (am *MockAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + if am.GetAllPATsFunc != nil { + return am.GetAllPATsFunc(accountID, executingUserID, targetUserID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAllPATs is not implemented") +} + // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { if am.GetNetworkMapFunc != nil { diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 7416a9e0b..bdf34e9fd 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "fmt" "hash/crc32" "time" @@ -25,7 +26,7 @@ const ( // PersonalAccessToken holds all information about a PAT including a hashed version of it for verification type PersonalAccessToken struct { ID string - Description string + Name string HashedToken string ExpirationDate time.Time // scope could be added in future @@ -34,23 +35,33 @@ type PersonalAccessToken struct { LastUsed time.Time } +// PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it +type PersonalAccessTokenGenerated struct { + PlainToken string + PersonalAccessToken +} + // CreateNewPAT will generate a new PersonalAccessToken that can be assigned to a User. // Additionally, it will return the token in plain text once, to give to the user and only save a hashed version -func CreateNewPAT(description string, expirationInDays int, createdBy string) (*PersonalAccessToken, string, error) { +func CreateNewPAT(name string, expirationInDays int, createdBy string) (*PersonalAccessTokenGenerated, error) { hashedToken, plainToken, err := generateNewToken() if err != nil { - return nil, "", err + return nil, err } currentTime := time.Now().UTC() - return &PersonalAccessToken{ - ID: xid.New().String(), - Description: description, - HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), - CreatedBy: createdBy, - CreatedAt: currentTime, - LastUsed: currentTime, - }, plainToken, nil + return &PersonalAccessTokenGenerated{ + PersonalAccessToken: PersonalAccessToken{ + ID: xid.New().String(), + Name: name, + HashedToken: hashedToken, + ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + CreatedBy: createdBy, + CreatedAt: currentTime, + LastUsed: time.Time{}, + }, + PlainToken: plainToken, + }, nil + } func generateNewToken() (string, string, error) { @@ -64,5 +75,6 @@ func generateNewToken() (string, string, error) { paddedChecksum := fmt.Sprintf("%06s", encodedChecksum) plainToken := PATPrefix + secret + paddedChecksum hashedToken := sha256.Sum256([]byte(plainToken)) - return string(hashedToken[:]), plainToken, nil + encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) + return encodedHashedToken, plainToken, nil } diff --git a/management/server/personal_access_token_test.go b/management/server/personal_access_token_test.go index a4e02f750..03dd2ef4e 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/personal_access_token_test.go @@ -2,6 +2,7 @@ package server import ( "crypto/sha256" + b64 "encoding/base64" "hash/crc32" "strings" "testing" @@ -13,7 +14,8 @@ import ( func TestPAT_GenerateToken_Hashing(t *testing.T) { hashedToken, plainToken, _ := generateNewToken() expectedToken := sha256.Sum256([]byte(plainToken)) - assert.Equal(t, hashedToken, string(expectedToken[:])) + encodedExpectedToken := b64.StdEncoding.EncodeToString(expectedToken[:]) + assert.Equal(t, hashedToken, encodedExpectedToken) } func TestPAT_GenerateToken_Prefix(t *testing.T) { diff --git a/management/server/user.go b/management/server/user.go index c3011c317..97e9cd0c7 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -193,57 +193,141 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us } -// AddPATToUser takes the userID and the accountID the user belongs to and assigns a provided PersonalAccessToken to that user -func (am *DefaultAccountManager) AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error { +// CreatePAT creates a new PAT for the given user +func (am *DefaultAccountManager) CreatePAT(accountID string, executingUserID string, targetUserId string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() + if tokenName == "" { + return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") + } + + if expiresIn < 1 || expiresIn > 365 { + return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") + } + + if executingUserID != targetUserId { + return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") + } + account, err := am.Store.GetAccount(accountID) if err != nil { - return err + return nil, err } - user := account.Users[userID] - if user == nil { - return status.Errorf(status.NotFound, "user not found") + targetUser := account.Users[targetUserId] + if targetUser == nil { + return nil, status.Errorf(status.NotFound, "targetUser not found") } - user.PATs[pat.ID] = pat + pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) + } - return am.Store.SaveAccount(account) + targetUser.PATs[pat.ID] = &pat.PersonalAccessToken + + err = am.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(status.Internal, "failed to save account: %v", err) + } + + return pat, nil } // DeletePAT deletes a specific PAT from a user -func (am *DefaultAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { +func (am *DefaultAccountManager) DeletePAT(accountID string, executingUserID string, targetUserID string, tokenID string) error { unlock := am.Store.AcquireAccountLock(accountID) defer unlock() - account, err := am.Store.GetAccount(accountID) - if err != nil { - return err + if executingUserID != targetUserID { + return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") } - user := account.Users[userID] + account, err := am.Store.GetAccount(accountID) + if err != nil { + return status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] if user == nil { return status.Errorf(status.NotFound, "user not found") } - pat := user.PATs["tokenID"] + pat := user.PATs[tokenID] if pat == nil { return status.Errorf(status.NotFound, "PAT not found") } err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) if err != nil { - return err + return status.Errorf(status.Internal, "Failed to delete token id index: %s", err) } err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) if err != nil { - return err + return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) } delete(user.PATs, tokenID) - return am.Store.SaveAccount(account) + err = am.Store.SaveAccount(account) + if err != nil { + return status.Errorf(status.Internal, "Failed to save account: %s", err) + } + return nil +} + +// GetPAT returns a specific PAT from a user +func (am *DefaultAccountManager) GetPAT(accountID string, executingUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if executingUserID != targetUserID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + + pat := user.PATs[tokenID] + if pat == nil { + return nil, status.Errorf(status.NotFound, "PAT not found") + } + + return pat, nil +} + +// GetAllPATs returns all PATs for a user +func (am *DefaultAccountManager) GetAllPATs(accountID string, executingUserID string, targetUserID string) ([]*PersonalAccessToken, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + if executingUserID != targetUserID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") + } + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, status.Errorf(status.NotFound, "account not found: %s", err) + } + + user := account.Users[targetUserID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user not found") + } + + var pats []*PersonalAccessToken + for _, pat := range user.PATs { + pats = append(pats, pat) + } + + return pats, nil } // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. diff --git a/management/server/user_test.go b/management/server/user_test.go index 20f2ca4f1..238aa2bff 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -7,13 +7,20 @@ import ( ) const ( - mockAccountID = "accountID" - mockUserID = "userID" - mockTokenID = "tokenID" - mockToken = "SoMeHaShEdToKeN" + mockAccountID = "accountID" + mockUserID = "userID" + mockTargetUserId = "targetUserID" + mockTokenID1 = "tokenID1" + mockToken1 = "SoMeHaShEdToKeN1" + mockTokenID2 = "tokenID2" + mockToken2 = "SoMeHaShEdToKeN2" + mockTokenName = "tokenName" + mockEmptyTokenName = "" + mockExpiresIn = 7 + mockWrongExpiresIn = 4506 ) -func TestUser_AddPATToUser(t *testing.T) { +func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") @@ -26,24 +33,19 @@ func TestUser_AddPATToUser(t *testing.T) { Store: store, } - pat := PersonalAccessToken{ - ID: mockTokenID, - HashedToken: mockToken, - } - - err = am.AddPATToUser(mockAccountID, mockUserID, &pat) + pat, err := am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } fileStore := am.Store.(*FileStore) - tokenID := fileStore.HashedPAT2TokenID[mockToken[:]] + tokenID := fileStore.HashedPAT2TokenID[pat.HashedToken] if tokenID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") } - assert.Equal(t, mockTokenID, tokenID) + assert.Equal(t, pat.ID, tokenID) userID := fileStore.TokenID2UserID[tokenID] if userID == "" { @@ -52,15 +54,66 @@ func TestUser_AddPATToUser(t *testing.T) { assert.Equal(t, mockUserID, userID) } +func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockTargetUserId, mockTokenName, mockExpiresIn) + assert.Errorf(t, err, "Creating PAT for different user should thorw error") +} + +func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockTokenName, mockWrongExpiresIn) + assert.Errorf(t, err, "Wrong expiration should thorw error") +} + +func TestUser_CreatePAT_WithEmptyName(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + _, err = am.CreatePAT(mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) + assert.Errorf(t, err, "Wrong expiration should thorw error") +} + func TestUser_DeletePAT(t *testing.T) { store := newStore(t) account := newAccountWithId(mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ Id: mockUserID, PATs: map[string]*PersonalAccessToken{ - mockTokenID: { - ID: mockTokenID, - HashedToken: mockToken, + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, }, }, } @@ -73,12 +126,75 @@ func TestUser_DeletePAT(t *testing.T) { Store: store, } - err = am.DeletePAT(mockAccountID, mockUserID, mockTokenID) + err = am.DeletePAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID]) - assert.Empty(t, store.HashedPAT2TokenID[mockToken]) - assert.Empty(t, store.TokenID2UserID[mockTokenID]) + assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID1]) + assert.Empty(t, store.HashedPAT2TokenID[mockToken1]) + assert.Empty(t, store.TokenID2UserID[mockTokenID1]) +} + +func TestUser_GetPAT(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + pat, err := am.GetPAT(mockAccountID, mockUserID, mockUserID, mockTokenID1) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Equal(t, mockTokenID1, pat.ID) + assert.Equal(t, mockToken1, pat.HashedToken) +} + +func TestUser_GetAllPATs(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID1: { + ID: mockTokenID1, + HashedToken: mockToken1, + }, + mockTokenID2: { + ID: mockTokenID2, + HashedToken: mockToken2, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + pats, err := am.GetAllPATs(mockAccountID, mockUserID, mockUserID) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Equal(t, 2, len(pats)) }