diff --git a/management/server/account.go b/management/server/account.go index 5b9c9402d..1d4c10721 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2,7 +2,9 @@ package server import ( "context" + "crypto/sha256" "fmt" + "hash/crc32" "math/rand" "net" "net/netip" @@ -12,17 +14,19 @@ import ( "sync" "time" + "codeberg.org/ac/base62" "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/route" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" ) const ( @@ -50,6 +54,7 @@ type AccountManager interface { GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + GetAccountFromPAT(pat string) (*Account, *User, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeerByKey(peerKey string) (*Peer, error) @@ -61,6 +66,8 @@ type AccountManager interface { GetNetworkMap(peerID string) (*NetworkMap, error) GetPeerNetwork(peerID string) (*Network, error) AddPeer(setupKey, userID string, peer *Peer) (*Peer, *NetworkMap, error) + AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error + DeletePAT(accountID string, userID string, tokenID string) error UpdatePeerSSHKey(peerID string, sshKey string) error GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) @@ -1112,6 +1119,47 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e return nil } +// GetAccountFromPAT returns Account and User associated with a personal access token +func (am *DefaultAccountManager) GetAccountFromPAT(token string) (*Account, *User, error) { + if len(token) != PATLength { + return nil, nil, fmt.Errorf("token has wrong length") + } + + prefix := token[:len(PATPrefix)] + if prefix != PATPrefix { + return nil, nil, fmt.Errorf("token has wrong prefix") + } + secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] + encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] + + verificationChecksum, err := base62.Decode(encodedChecksum) + if err != nil { + return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + } + + secretChecksum := crc32.ChecksumIEEE([]byte(secret)) + if secretChecksum != verificationChecksum { + return nil, nil, fmt.Errorf("token checksum does not match") + } + + hashedToken := sha256.Sum256([]byte(token)) + tokenID, err := am.Store.GetTokenIDByHashedToken(string(hashedToken[:])) + if err != nil { + return nil, nil, err + } + + user, err := am.Store.GetUserByTokenID(tokenID) + if err != nil { + return nil, nil, err + } + + account, err := am.Store.GetAccountByUser(user.Id) + if err != nil { + return nil, nil, err + } + return account, user, nil +} + // GetAccountFromToken returns an account associated with this token func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { if claims.UserId == "" { diff --git a/management/server/account_test.go b/management/server/account_test.go index e40d5e5b8..5b4b1cc17 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1,6 +1,7 @@ package server import ( + "crypto/sha256" "fmt" "net" "reflect" @@ -458,6 +459,39 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { } } +func TestAccountManager_GetAccountFromPAT(t *testing.T) { + store := newStore(t) + account := newAccountWithId("account_id", "testuser", "") + + token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" + hashedToken := sha256.Sum256([]byte(token)) + account.Users["someUser"] = &User{ + Id: "someUser", + PATs: map[string]*PersonalAccessToken{ + "pat1": { + ID: "tokenId", + HashedToken: string(hashedToken[:]), + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + account, user, err := am.GetAccountFromPAT(token) + if err != nil { + t.Fatalf("Error when getting Account from PAT: %s", err) + } + + assert.Equal(t, "account_id", account.Id) + assert.Equal(t, "someUser", user.Id) +} + func TestAccountManager_PrivateAccount(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -1208,8 +1242,8 @@ func TestAccount_Copy(t *testing.T) { Id: "user1", Role: UserRoleAdmin, AutoGroups: []string{"group1"}, - PATs: []PersonalAccessToken{ - { + PATs: map[string]*PersonalAccessToken{ + "pat1": { ID: "pat1", Description: "First PAT", HashedToken: "SoMeHaShEdToKeN", diff --git a/management/server/file_store.go b/management/server/file_store.go index e3ec5af44..4f8092cfb 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -7,10 +7,11 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/util" ) @@ -25,6 +26,8 @@ type FileStore struct { PeerID2AccountID map[string]string `json:"-"` UserID2AccountID map[string]string `json:"-"` PrivateDomain2AccountID map[string]string `json:"-"` + HashedPAT2TokenID map[string]string `json:"-"` + TokenID2UserID map[string]string `json:"-"` InstallationID string // mutex to synchronise Store read/write operations @@ -57,6 +60,8 @@ func restore(file string) (*FileStore, error) { UserID2AccountID: make(map[string]string), PrivateDomain2AccountID: make(map[string]string), PeerID2AccountID: make(map[string]string), + HashedPAT2TokenID: make(map[string]string), + TokenID2UserID: make(map[string]string), storeFile: file, } @@ -80,6 +85,8 @@ func restore(file string) (*FileStore, error) { store.UserID2AccountID = make(map[string]string) store.PrivateDomain2AccountID = make(map[string]string) store.PeerID2AccountID = make(map[string]string) + store.HashedPAT2TokenID = make(map[string]string) + store.TokenID2UserID = make(map[string]string) for accountID, account := range store.Accounts { if account.Settings == nil { @@ -103,9 +110,10 @@ func restore(file string) (*FileStore, error) { } for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID - } - for _, user := range account.Users { - store.UserID2AccountID[user.Id] = accountID + for _, pat := range user.PATs { + store.TokenID2UserID[pat.ID] = user.Id + store.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + } } if account.Domain != "" && account.DomainCategory == PrivateCategory && @@ -258,6 +266,10 @@ func (s *FileStore) SaveAccount(account *Account) error { for _, user := range accountCopy.Users { s.UserID2AccountID[user.Id] = accountCopy.Id + for _, pat := range user.PATs { + s.TokenID2UserID[pat.ID] = user.Id + s.HashedPAT2TokenID[pat.HashedToken[:]] = pat.ID + } } if accountCopy.DomainCategory == PrivateCategory && accountCopy.IsDomainPrimaryAccount { @@ -276,13 +288,33 @@ func (s *FileStore) SaveAccount(account *Account) error { return s.persist(s.storeFile) } +// DeleteHashedPAT2TokenIDIndex removes an entry from the indexing map HashedPAT2TokenID +func (s *FileStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { + s.mux.Lock() + defer s.mux.Unlock() + + delete(s.HashedPAT2TokenID, hashedToken) + + return s.persist(s.storeFile) +} + +// DeleteTokenID2UserIDIndex removes an entry from the indexing map TokenID2UserID +func (s *FileStore) DeleteTokenID2UserIDIndex(tokenID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + delete(s.TokenID2UserID, tokenID) + + return s.persist(s.storeFile) +} + // GetAccountByPrivateDomain returns account by private domain func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() - accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)] - if !accountIDFound { + accountID, ok := s.PrivateDomain2AccountID[strings.ToLower(domain)] + if !ok { return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") } @@ -299,8 +331,8 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() - accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] - if !accountIDFound { + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] + if !ok { return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") } @@ -312,6 +344,42 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { return account.Copy(), nil } +// GetTokenIDByHashedToken returns the id of a personal access token by its hashed secret +func (s *FileStore) GetTokenIDByHashedToken(token string) (string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + tokenID, ok := s.HashedPAT2TokenID[token] + if !ok { + return "", status.Errorf(status.NotFound, "tokenID not found: provided token doesn't exists") + } + + return tokenID, nil +} + +// GetUserByTokenID returns a User object a tokenID belongs to +func (s *FileStore) GetUserByTokenID(tokenID string) (*User, error) { + s.mux.Lock() + defer s.mux.Unlock() + + userID, ok := s.TokenID2UserID[tokenID] + if !ok { + return nil, status.Errorf(status.NotFound, "user not found: provided tokenID doesn't exists") + } + + accountID, ok := s.UserID2AccountID[userID] + if !ok { + return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Users[userID].Copy(), nil +} + // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts() (all []*Account) { s.mux.Lock() @@ -325,8 +393,8 @@ func (s *FileStore) GetAllAccounts() (all []*Account) { // getAccount returns a reference to the Account. Should not return a copy. func (s *FileStore) getAccount(accountID string) (*Account, error) { - account, accountFound := s.Accounts[accountID] - if !accountFound { + account, ok := s.Accounts[accountID] + if !ok { return nil, status.Errorf(status.NotFound, "account not found") } @@ -351,8 +419,8 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() - accountID, accountIDFound := s.UserID2AccountID[userID] - if !accountIDFound { + accountID, ok := s.UserID2AccountID[userID] + if !ok { return nil, status.Errorf(status.NotFound, "account not found") } @@ -369,8 +437,8 @@ func (s *FileStore) GetAccountByPeerID(peerID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() - accountID, accountIDFound := s.PeerID2AccountID[peerID] - if !accountIDFound { + accountID, ok := s.PeerID2AccountID[peerID] + if !ok { return nil, status.Errorf(status.NotFound, "provided peer ID doesn't exists %s", peerID) } @@ -395,8 +463,8 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() - accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey] - if !accountIDFound { + accountID, ok := s.PeerKeyID2AccountID[peerKey] + if !ok { return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey) } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 287f043d9..b2e9bff29 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -1,14 +1,16 @@ package server import ( + "crypto/sha256" "net" "path/filepath" "testing" "time" - "github.com/netbirdio/netbird/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/util" ) type accounts struct { @@ -71,6 +73,14 @@ func TestNewStore(t *testing.T) { if store.UserID2AccountID == nil || len(store.UserID2AccountID) != 0 { t.Errorf("expected to create a new empty UserID2AccountID map when creating a new FileStore") } + + if store.HashedPAT2TokenID == nil || len(store.HashedPAT2TokenID) != 0 { + t.Errorf("expected to create a new empty HashedPAT2TokenID map when creating a new FileStore") + } + + if store.TokenID2UserID == nil || len(store.TokenID2UserID) != 0 { + t.Errorf("expected to create a new empty TokenID2UserID map when creating a new FileStore") + } } func TestSaveAccount(t *testing.T) { @@ -239,11 +249,17 @@ func TestRestore(t *testing.T) { require.NotNil(t, account.SetupKeys["A2C8E62B-38F5-4553-B31E-DD66C696CEBB"], "failed to restore a FileStore file - missing Account SetupKey A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NotNil(t, account.Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"], "failed to restore a FileStore wrong PATs length") + require.Len(t, store.UserID2AccountID, 2, "failed to restore a FileStore wrong UserID2AccountID mapping length") require.Len(t, store.SetupKeyID2AccountID, 1, "failed to restore a FileStore wrong SetupKeyID2AccountID mapping length") require.Len(t, store.PrivateDomain2AccountID, 1, "failed to restore a FileStore wrong PrivateDomain2AccountID mapping length") + + require.Len(t, store.HashedPAT2TokenID, 1, "failed to restore a FileStore wrong HashedPAT2TokenID mapping length") + + require.Len(t, store.TokenID2UserID, 1, "failed to restore a FileStore wrong TokenID2UserID mapping length") } func TestRestorePolicies_Migration(t *testing.T) { @@ -348,6 +364,137 @@ func TestFileStore_GetAccount(t *testing.T) { assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) } +func TestFileStore_GetTokenIDByHashedToken(t *testing.T) { + storeDir := t.TempDir() + storeFile := filepath.Join(storeDir, "store.json") + err := util.CopyFileContents("testdata/store.json", storeFile) + if err != nil { + t.Fatal(err) + } + + accounts := &accounts{} + _, err = util.ReadJson(storeFile, accounts) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir) + if err != nil { + t.Fatal(err) + } + + hashedToken := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].HashedToken + tokenID, err := store.GetTokenIDByHashedToken(hashedToken) + if err != nil { + t.Fatal(err) + } + + expectedTokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID + assert.Equal(t, expectedTokenID, tokenID) +} + +func TestFileStore_DeleteHashedPAT2TokenIDIndex(t *testing.T) { + store := newStore(t) + store.HashedPAT2TokenID["someHashedToken"] = "someTokenId" + + err := store.DeleteHashedPAT2TokenIDIndex("someHashedToken") + if err != nil { + t.Fatal(err) + } + + assert.Empty(t, store.HashedPAT2TokenID["someHashedToken"]) +} + +func TestFileStore_DeleteTokenID2UserIDIndex(t *testing.T) { + store := newStore(t) + store.TokenID2UserID["someTokenId"] = "someUserId" + + err := store.DeleteTokenID2UserIDIndex("someTokenId") + if err != nil { + t.Fatal(err) + } + + assert.Empty(t, store.TokenID2UserID["someTokenId"]) +} + +func TestFileStore_GetTokenIDByHashedToken_Failure(t *testing.T) { + storeDir := t.TempDir() + storeFile := filepath.Join(storeDir, "store.json") + err := util.CopyFileContents("testdata/store.json", storeFile) + if err != nil { + t.Fatal(err) + } + + accounts := &accounts{} + _, err = util.ReadJson(storeFile, accounts) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir) + if err != nil { + t.Fatal(err) + } + + wrongToken := sha256.Sum256([]byte("someNotValidTokenThatFails1234")) + _, err = store.GetTokenIDByHashedToken(string(wrongToken[:])) + + assert.Error(t, err, "GetTokenIDByHashedToken should throw error if token invalid") +} + +func TestFileStore_GetUserByTokenID(t *testing.T) { + storeDir := t.TempDir() + storeFile := filepath.Join(storeDir, "store.json") + err := util.CopyFileContents("testdata/store.json", storeFile) + if err != nil { + t.Fatal(err) + } + + accounts := &accounts{} + _, err = util.ReadJson(storeFile, accounts) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir) + if err != nil { + t.Fatal(err) + } + + tokenID := accounts.Accounts["bf1c8084-ba50-4ce7-9439-34653001fc3b"].Users["f4f6d672-63fb-11ec-90d6-0242ac120003"].PATs["9dj38s35-63fb-11ec-90d6-0242ac120003"].ID + user, err := store.GetUserByTokenID(tokenID) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) +} + +func TestFileStore_GetUserByTokenID_Failure(t *testing.T) { + storeDir := t.TempDir() + storeFile := filepath.Join(storeDir, "store.json") + err := util.CopyFileContents("testdata/store.json", storeFile) + if err != nil { + t.Fatal(err) + } + + accounts := &accounts{} + _, err = util.ReadJson(storeFile, accounts) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir) + if err != nil { + t.Fatal(err) + } + + wrongTokenID := "someNonExistingTokenID" + _, err = store.GetUserByTokenID(wrongTokenID) + + assert.Error(t, err, "GetUserByTokenID should throw error if tokenID invalid") +} + func TestFileStore_SavePeerStatus(t *testing.T) { storeDir := t.TempDir() diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 9da748757..2ae71d1a9 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,6 +47,7 @@ type MockAccountManager struct { DeletePolicyFunc func(accountID, policyID, userID string) error ListPoliciesFunc func(accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) + GetAccountFromPATFunc func(pat string) (*server.Account, *server.User, error) UpdatePeerMetaFunc func(peerID string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerID string, sshKey string) error UpdatePeerFunc func(accountID, userID string, peer *server.Peer) (*server.Peer, error) @@ -59,6 +60,8 @@ type MockAccountManager struct { SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID, userID string, user *server.User) (*server.UserInfo, error) + AddPATToUserFunc func(accountID string, userID string, pat *server.PersonalAccessToken) error + DeletePATFunc func(accountID string, userID string, tokenID string) error GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error @@ -175,6 +178,30 @@ func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*ser return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented") } +// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface +func (am *MockAccountManager) GetAccountFromPAT(pat string) (*server.Account, *server.User, error) { + if am.GetAccountFromPATFunc != nil { + return am.GetAccountFromPATFunc(pat) + } + return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") +} + +// AddPATToUser mock implementation of AddPATToUser from server.AccountManager interface +func (am *MockAccountManager) AddPATToUser(accountID string, userID string, pat *server.PersonalAccessToken) error { + if am.AddPATToUserFunc != nil { + return am.AddPATToUserFunc(accountID, userID, pat) + } + return status.Errorf(codes.Unimplemented, "method AddPATToUser is not implemented") +} + +// DeletePAT mock implementation of DeletePAT from server.AccountManager interface +func (am *MockAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { + if am.DeletePATFunc != nil { + return am.DeletePATFunc(accountID, userID, tokenID) + } + return status.Errorf(codes.Unimplemented, "method DeletePAT is not implemented") +} + // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, error) { if am.GetNetworkMapFunc != nil { diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index e7ee05dad..7416a9e0b 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -13,8 +13,13 @@ import ( const ( // PATPrefix is the globally used, 4 char prefix for personal access tokens - PATPrefix = "nbp_" - secretLength = 30 + PATPrefix = "nbp_" + // PATSecretLength number of characters used for the secret inside the token + PATSecretLength = 30 + // PATChecksumLength number of characters used for the encoded checksum of the secret inside the token + PATChecksumLength = 6 + // PATLength total number of characters used for the token + PATLength = 40 ) // PersonalAccessToken holds all information about a PAT including a hashed version of it for verification @@ -49,7 +54,7 @@ func CreateNewPAT(description string, expirationInDays int, createdBy string) (* } func generateNewToken() (string, string, error) { - secret, err := b.Random(secretLength) + secret, err := b.Random(PATSecretLength) if err != nil { return "", "", err } diff --git a/management/server/store.go b/management/server/store.go index 02041dfda..daad30eaa 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -6,9 +6,13 @@ type Store interface { GetAccountByUser(userID string) (*Account, error) GetAccountByPeerPubKey(peerKey string) (*Account, error) GetAccountByPeerID(peerID string) (*Account, error) - GetAccountBySetupKey(setupKey string) (*Account, error) //todo use key hash later + GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(domain string) (*Account, error) + GetTokenIDByHashedToken(secret string) (string, error) + GetUserByTokenID(tokenID string) (*User, error) SaveAccount(account *Account) error + DeleteHashedPAT2TokenIDIndex(hashedToken string) error + DeleteTokenID2UserIDIndex(tokenID string) error GetInstallationID() string SaveInstallationID(ID string) error // AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json index 8eddeca23..ecde766c3 100644 --- a/management/server/testdata/store.json +++ b/management/server/testdata/store.json @@ -29,11 +29,23 @@ "Users": { "edafee4e-63fb-11ec-90d6-0242ac120003": { "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", - "Role": "admin" + "Role": "admin", + "PATs": {} }, "f4f6d672-63fb-11ec-90d6-0242ac120003": { "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", - "Role": "user" + "Role": "user", + "PATs": { + "9dj38s35-63fb-11ec-90d6-0242ac120003": { + "ID":"9dj38s35-63fb-11ec-90d6-0242ac120003", + "Description":"some Description", + "HashedToken":"SoMeHaShEdToKeN", + "ExpirationDate":"2023-02-27T00:00:00Z", + "CreatedBy":"user", + "CreatedAt":"2023-01-01T00:00:00Z", + "LastUsed":"2023-02-01T00:00:00Z" + } + } } } } diff --git a/management/server/user.go b/management/server/user.go index a9aaa1b61..c3011c317 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -46,7 +46,7 @@ type User struct { Role UserRole // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user AutoGroups []string - PATs []PersonalAccessToken + PATs map[string]*PersonalAccessToken } // IsAdmin returns true if user is an admin, false otherwise @@ -94,8 +94,12 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { func (u *User) Copy() *User { autoGroups := make([]string, len(u.AutoGroups)) copy(autoGroups, u.AutoGroups) - pats := make([]PersonalAccessToken, len(u.PATs)) - copy(pats, u.PATs) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) + for k, v := range u.PATs { + patCopy := new(PersonalAccessToken) + *patCopy = *v + pats[k] = patCopy + } return &User{ Id: u.Id, Role: u.Role, @@ -189,6 +193,59 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, invite *Us } +// AddPATToUser takes the userID and the accountID the user belongs to and assigns a provided PersonalAccessToken to that user +func (am *DefaultAccountManager) AddPATToUser(accountID string, userID string, pat *PersonalAccessToken) error { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + user := account.Users[userID] + if user == nil { + return status.Errorf(status.NotFound, "user not found") + } + + user.PATs[pat.ID] = pat + + return am.Store.SaveAccount(account) +} + +// DeletePAT deletes a specific PAT from a user +func (am *DefaultAccountManager) DeletePAT(accountID string, userID string, tokenID string) error { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + user := account.Users[userID] + if user == nil { + return status.Errorf(status.NotFound, "user not found") + } + + pat := user.PATs["tokenID"] + if pat == nil { + return status.Errorf(status.NotFound, "PAT not found") + } + + err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) + if err != nil { + return err + } + err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) + if err != nil { + return err + } + delete(user.PATs, tokenID) + + return am.Store.SaveAccount(account) +} + // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // Only User.AutoGroups field is allowed to be updated for now. func (am *DefaultAccountManager) SaveUser(accountID, userID string, update *User) (*UserInfo, error) { diff --git a/management/server/user_test.go b/management/server/user_test.go new file mode 100644 index 000000000..20f2ca4f1 --- /dev/null +++ b/management/server/user_test.go @@ -0,0 +1,84 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + mockAccountID = "accountID" + mockUserID = "userID" + mockTokenID = "tokenID" + mockToken = "SoMeHaShEdToKeN" +) + +func TestUser_AddPATToUser(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + pat := PersonalAccessToken{ + ID: mockTokenID, + HashedToken: mockToken, + } + + err = am.AddPATToUser(mockAccountID, mockUserID, &pat) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + fileStore := am.Store.(*FileStore) + tokenID := fileStore.HashedPAT2TokenID[mockToken[:]] + + if tokenID == "" { + t.Fatal("GetTokenIDByHashedToken failed after adding PAT") + } + + assert.Equal(t, mockTokenID, tokenID) + + userID := fileStore.TokenID2UserID[tokenID] + if userID == "" { + t.Fatal("GetUserByTokenId failed after adding PAT") + } + assert.Equal(t, mockUserID, userID) +} + +func TestUser_DeletePAT(t *testing.T) { + store := newStore(t) + account := newAccountWithId(mockAccountID, mockUserID, "") + account.Users[mockUserID] = &User{ + Id: mockUserID, + PATs: map[string]*PersonalAccessToken{ + mockTokenID: { + ID: mockTokenID, + HashedToken: mockToken, + }, + }, + } + err := store.SaveAccount(account) + if err != nil { + t.Fatalf("Error when saving account: %s", err) + } + + am := DefaultAccountManager{ + Store: store, + } + + err = am.DeletePAT(mockAccountID, mockUserID, mockTokenID) + if err != nil { + t.Fatalf("Error when adding PAT to user: %s", err) + } + + assert.Nil(t, store.Accounts[mockAccountID].Users[mockUserID].PATs[mockTokenID]) + assert.Empty(t, store.HashedPAT2TokenID[mockToken]) + assert.Empty(t, store.TokenID2UserID[mockTokenID]) +}