diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 23b253c5a..c9fc51c9e 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1533,12 +1533,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db, lockStrength, accountID) + var routes []*route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&routes, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get routes from store") + } + + return routes, nil } // GetRouteByID retrieves a route by its ID and account ID. -func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { - return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { + var route *route.Route + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&route, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewRouteNotFoundError(routeID) + } + log.WithContext(ctx).Errorf("failed to get route from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get route from store") + } + + return route, nil +} + +// SaveRoute saves a route to the database. +func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) + return status.Errorf(status.Internal, "failed to save route to store") + } + + return nil +} + +// DeleteRoute deletes a route from the database. +func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete route from store") + } + + if result.RowsAffected == 0 { + return status.NewRouteNotFoundError(routeID) + } + + return nil } // GetAccountSetupKeys retrieves setup keys for an account. @@ -1649,39 +1695,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki return nil } -// getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { - var record []T - - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) - } - - return record, nil -} - -// getRecordByID retrieves a record by its ID and account ID from the database. -func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { - var record T - - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&record, accountAndIDQueryCondition, accountID, recordID) - if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "%s not found", recordType) - } - return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) - } - return &record, nil -} - // SaveDNSSettings saves the DNS settings to the store. func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). diff --git a/management/server/status/error.go b/management/server/status/error.go index 045469306..9a5d06530 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -174,3 +174,7 @@ func NewOwnerDeletePermissionError() error { func NewPATNotFoundError(patID string) error { return Errorf(NotFound, "PAT: %s not found", patID) } + +func NewRouteNotFoundError(routeID string) error { + return Errorf(NotFound, "route: %s not found", routeID) +} diff --git a/management/server/store.go b/management/server/store.go index 01a4955c1..440caad75 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -128,7 +128,9 @@ type Store interface { DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) - GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) + SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error + DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)