diff --git a/management/server/account.go b/management/server/account.go index 72c866289..c2ca640bd 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -88,7 +88,7 @@ type AccountManager interface { GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error GetUserByID(ctx context.Context, id string) (*User, error) @@ -1869,52 +1869,59 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin return am.Store.GetAccount(ctx, accountID) } -// GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { +// GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token. +func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) { + user, pat, err = am.extractPATFromToken(ctx, token) + if err != nil { + return nil, nil, "", "", err + } + + domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID) + if err != nil { + return nil, nil, "", "", err + } + + return user, pat, domain, category, nil +} + +// extractPATFromToken validates the token structure and retrieves associated User and PAT. +func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) { if len(token) != PATLength { - return nil, nil, nil, fmt.Errorf("token has wrong length") + return nil, nil, fmt.Errorf("token has incorrect length") } prefix := token[:len(PATPrefix)] if prefix != PATPrefix { - return nil, nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, fmt.Errorf("token has incorrect prefix") } + secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { - return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { - return nil, nil, nil, fmt.Errorf("token checksum does not match") + return nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken) + + pat, err := am.Store.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - user, err := am.Store.GetUserByTokenID(ctx, tokenID) + user, err := am.Store.GetUserByPATID(ctx, LockingStrengthShare, pat.ID) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - account, err := am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return nil, nil, nil, err - } - - pat := user.PATs[tokenID] - if pat == nil { - return nil, nil, nil, fmt.Errorf("personal access token not found") - } - - return account, user, pat, nil + return user, pat, nil } // GetAccountByID returns an account associated with this account ID. diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c3928bff6..bb6d00209 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa ) authMiddleware := middleware.NewAuthMiddleware( - accountManager.GetAccountFromPAT, + accountManager.GetAccountInfoFromPAT, jwtValidator.ValidateAndParse, accountManager.MarkPATUsed, accountManager.CheckUserAccessByJWTGroups, diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index b25aad99c..c1502e479 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -9,9 +9,9 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" @@ -19,8 +19,8 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -// GetAccountFromPATFunc function -type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +// GetAccountInfoFromPATFunc function +type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) @@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - getAccountFromPAT GetAccountFromPATFunc + getAccountInfoFromPAT GetAccountInfoFromPATFunc validateAndParseToken ValidateAndParseTokenFunc markPATUsed MarkPATUsedFunc checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc @@ -47,7 +47,7 @@ const ( ) // NewAuthMiddleware instance constructor -func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, +func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc, markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor, audience string, userIdClaim string) *AuthMiddleware { if userIdClaim == "" { @@ -55,7 +55,7 @@ func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParse } return &AuthMiddleware{ - getAccountFromPAT: getAccountFromPAT, + getAccountInfoFromPAT: getAccountInfoFromPAT, validateAndParseToken: validateAndParseToken, markPATUsed: markPATUsed, checkUserAccessByJWTGroups: checkUserAccessByJWTGroups, @@ -116,7 +116,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(w http.ResponseWriter, r *http.Requ // If an error occurs, call the error handler and return an error if err != nil { - return fmt.Errorf("Error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } validatedToken, err := m.validateAndParseToken(r.Context(), token) @@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j // CheckPATFromRequest checks if the PAT is valid func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error { token, err := getTokenFromPATRequest(auth) - - // If an error occurs, call the error handler and return an error if err != nil { - return fmt.Errorf("Error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } - account, user, pat, err := m.getAccountFromPAT(r.Context(), token) + user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token) if err != nil { return fmt.Errorf("invalid Token: %w", err) } @@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps := jwt.MapClaims{} claimMaps[m.userIDClaim] = user.Id - claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id - claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain - claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID + claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain + claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint // Update the current request with the new context information. diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2f91a0478..8c8496b50 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -55,7 +55,7 @@ type MockAccountManager struct { DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -235,12 +235,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { - if am.GetAccountFromPATFunc != nil { - return am.GetAccountFromPATFunc(ctx, pat) +// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface +func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) { + if am.GetAccountInfoFromPATFunc != nil { + return am.GetAccountInfoFromPATFunc(ctx, token) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") + return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented") } // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a5316d72d..c5acb3c5b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -475,49 +475,6 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* return s.GetAccount(ctx, key.AccountID) } -func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { - var token PersonalAccessToken - result := s.db.First(&token, "hashed_token = ?", hashedToken) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", status.Errorf(status.NotFound, "account not found: index lookup failed") - } - log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return "", status.NewGetAccountFromStoreError(result.Error) - } - - return token.ID, nil -} - -func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { - var token PersonalAccessToken - result := s.db.First(&token, idQueryCondition, tokenID) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") - } - log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error) - return nil, status.NewGetAccountFromStoreError(result.Error) - } - - if token.UserID == "" { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") - } - - var user User - result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID) - if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") - } - - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) - for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() - } - - return &user, nil -} - func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { var user User result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). @@ -526,6 +483,23 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) } + log.WithContext(ctx).Errorf("failed to get user from the store: %s", result.Error) + return nil, status.NewGetUserFromStoreError() + } + + return &user, nil +} + +func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) { + var user User + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). + Where("personal_access_tokens.id = ?", patID).First(&user) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPATNotFoundError() + } + log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error) return nil, status.NewGetUserFromStoreError() } @@ -1635,6 +1609,21 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki return nil } +// GetPATByHashedToken returns a PersonalAccessToken by its hashed token. +func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) { + var pat PersonalAccessToken + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPATNotFoundError() + } + log.WithContext(ctx).Errorf("failed to get pat from the store: %s", result.Error) + return nil, status.NewGetPATFromStoreError() + } + + return &pat, nil +} + // GetPATByID retrieves a personal access token by its ID and user ID. func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*PersonalAccessToken, error) { var pat PersonalAccessToken @@ -1642,10 +1631,10 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, First(&pat, "id = ? AND user_id = ?", patID, userID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "PAT not found") + return nil, status.NewPATNotFoundError() } log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err) - return nil, status.Errorf(status.Internal, "failed to get PAT from store") + return nil, status.NewGetPATFromStoreError() } return &pat, nil diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index a8e6576ed..b04060583 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -572,11 +572,11 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) + pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed) require.NoError(t, err) - require.Equal(t, id, token) + require.Equal(t, id, pat.ID) - _, err = store.GetTokenIDByHashedToken(context.Background(), "non-existing-hash") + _, err = store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "non-existing-hash") require.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -595,11 +595,11 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(context.Background(), id) + user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) - _, err = store.GetUserByTokenID(context.Background(), "non-existing-id") + _, err = store.GetUserByPATID(context.Background(), LockingStrengthShare, "non-existing-id") require.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -967,9 +967,9 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { hashed := "SoMeHaShEdToKeN" id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - token, err := store.GetTokenIDByHashedToken(context.Background(), hashed) + pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, hashed) require.NoError(t, err) - require.Equal(t, id, token) + require.Equal(t, id, pat.ID) } func TestPostgresql_GetUserByTokenID(t *testing.T) { @@ -984,7 +984,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByTokenID(context.Background(), id) + user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) } diff --git a/management/server/status/error.go b/management/server/status/error.go index 0dd302dfa..cbdce69db 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -146,6 +146,10 @@ func NewPATNotFoundError() error { return Errorf(NotFound, "PAT not found") } +func NewGetPATFromStoreError() error { + return Errorf(Internal, "issue getting pat from store") +} + func NewUnauthorizedToViewPATsError() error { return Errorf(PermissionDenied, "only users with admin power can view PATs") } diff --git a/management/server/store.go b/management/server/store.go index ddaf37e17..0e4cda982 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -63,13 +63,12 @@ type Store interface { DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error - GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) + GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error - GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) @@ -125,6 +124,7 @@ type Store interface { GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) + GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54..91a9d6245 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -55,25 +55,25 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) + newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Equal(t, pat.CreatedBy, mockUserID) + assert.Equal(t, newPAT.CreatedBy, mockUserID) - tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) + pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken) if err != nil { t.Fatalf("Error when getting token ID by hashed token: %s", err) } - if tokenID == "" { + if pat.ID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") } - assert.Equal(t, pat.ID, tokenID) + assert.Equal(t, newPAT.ID, pat.ID) - user, err := am.Store.GetUserByTokenID(context.Background(), tokenID) + user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID) if err != nil { t.Fatalf("Error when getting user by token ID: %s", err) }