mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Merge branch 'main' into feature/port-forwarding
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -414,24 +415,16 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
|
||||
}
|
||||
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error {
|
||||
usersToSave := make([]types.User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
}
|
||||
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
|
||||
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error {
|
||||
if len(users) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&users)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save users to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -439,7 +432,8 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) err
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
|
||||
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save user to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -450,7 +444,7 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength,
|
||||
return nil
|
||||
}
|
||||
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&groups)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
|
||||
}
|
||||
@@ -526,30 +520,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
|
||||
return token.ID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) {
|
||||
var token types.PersonalAccessToken
|
||||
result := s.db.First(&token, idQueryCondition, tokenID)
|
||||
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
|
||||
Where("personal_access_tokens.id = ?", patID).First(&user)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return nil, status.NewPATNotFoundError(patID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
if token.UserID == "" {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
|
||||
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
|
||||
for _, pat := range user.PATsG {
|
||||
user.PATs[pat.ID] = pat.Copy()
|
||||
log.WithContext(ctx).Errorf("failed to get token user from the store: %s", result.Error)
|
||||
return nil, status.NewGetUserFromStoreError()
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
@@ -557,8 +538,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
var user types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).First(&user, idQueryCondition, userID)
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@@ -569,6 +549,25 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error {
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
|
||||
})
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete user from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
|
||||
var users []*types.User
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
@@ -899,6 +898,20 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||
var createdBy string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
Select("created_by").First(&createdBy, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return createdBy, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user types.User
|
||||
@@ -956,7 +969,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig(SqliteStoreEngine))
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -966,7 +979,7 @@ func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMe
|
||||
|
||||
// NewPostgresqlStore creates a new Postgres store.
|
||||
func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), getGormConfig(PostgresStoreEngine))
|
||||
db, err := gorm.Open(postgres.Open(dsn), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -976,7 +989,7 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe
|
||||
|
||||
// NewMysqlStore creates a new MySQL store.
|
||||
func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) {
|
||||
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig(MysqlStoreEngine))
|
||||
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -984,15 +997,10 @@ func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics
|
||||
return NewSqlStore(ctx, db, MysqlStoreEngine, metrics)
|
||||
}
|
||||
|
||||
func getGormConfig(engine Engine) *gorm.Config {
|
||||
prepStmt := true
|
||||
if engine == SqliteStoreEngine {
|
||||
prepStmt = false
|
||||
}
|
||||
func getGormConfig() *gorm.Config {
|
||||
return &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
CreateBatchSize: 400,
|
||||
PrepareStmt: prepStmt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2061,3 +2069,94 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
|
||||
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&pat, "hashed_token = ?", hashedToken)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError(hashedToken)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat by hash from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get pat by hash from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
|
||||
var pat types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewPATNotFoundError(patID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get pat from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get pat from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// GetUserPATs retrieves personal access tokens for a user.
|
||||
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
|
||||
var pats []*types.PersonalAccessToken
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&pats, "user_id = ?", userID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get user pat's from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get user pat's from store")
|
||||
}
|
||||
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used.
|
||||
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
||||
patCopy := types.PersonalAccessToken{
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"last_used"}
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate).
|
||||
Where(idQueryCondition, patID).Updates(&patCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark pat as used")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError(patID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePAT saves a personal access token to the database.
|
||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save pat to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a personal access token from the database.
|
||||
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete pat from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError(patID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user