Merge branch 'netbirdio:main' into main

This commit is contained in:
İsmail
2024-11-12 22:26:50 +03:00
committed by GitHub
25 changed files with 341 additions and 179 deletions

View File

@@ -503,9 +503,10 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
return nil, status.NewSetupKeyNotFoundError(setupKey)
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
log.WithContext(ctx).Errorf("failed to get account by setup key from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account by setup key from store")
}
if key.AccountID == "" {
@@ -586,15 +587,15 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
return users, nil
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
log.WithContext(ctx).Errorf("failed to get account groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get account groups from the store")
}
return groups, nil
@@ -775,9 +776,10 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
result := s.db.Model(&SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
return "", status.NewSetupKeyNotFoundError(setupKey)
}
return "", status.NewSetupKeyNotFoundError(result.Error)
log.WithContext(ctx).Errorf("failed to get account ID by setup key from store: %v", result.Error)
return "", status.Errorf(status.Internal, "failed to get account ID by setup key from store")
}
if accountID == "" {
@@ -1049,9 +1051,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
return nil, status.NewSetupKeyNotFoundError(key)
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
log.WithContext(ctx).Errorf("failed to get setup key by secret from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get setup key by secret from store")
}
return &setupKey, nil
}
@@ -1069,7 +1072,7 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "setup key not found")
return status.NewSetupKeyNotFoundError(setupKeyID)
}
return nil
@@ -1247,6 +1250,23 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
return &group, nil
}
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
}
groupsMap := make(map[string]*nbgroup.Group)
for _, group := range groups {
groupsMap[group.ID] = group
}
return groupsMap, nil
}
// SaveGroup saves a group to the store.
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
@@ -1288,12 +1308,57 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db, lockStrength, accountID)
var setupKeys []*SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&setupKeys, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
}
return setupKeys, nil
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db, lockStrength, setupKeyID, accountID)
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
var setupKey *SetupKey
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID)
}
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
}
return setupKey, nil
}
// SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save setup key to store")
}
return nil
}
// DeleteSetupKey deletes a setup key from the database.
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete setup key from store")
}
if result.RowsAffected == 0 {
return status.NewSetupKeyNotFoundError(keyID)
}
return nil
}
// GetAccountNameServerGroups retrieves name server groups for an account.
@@ -1306,10 +1371,6 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID)
}
func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
return deleteRecordByID[SetupKey](s.db, LockingStrengthUpdate, keyID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
@@ -1342,21 +1403,3 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
}
return &record, nil
}
// deleteRecordByID deletes a record by its ID and account ID from the database.
func deleteRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) error {
var record T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return status.Errorf(status.Internal, "failed to delete %s from store: %v", recordType, err)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "record not found")
}
return nil
}