mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] Store connected proxies in DB (#5472)
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
This commit is contained in:
@@ -28,9 +28,10 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -2075,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
@@ -2090,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
|
||||
var s reverseproxy.Service
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
@@ -2121,7 +2122,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
}
|
||||
|
||||
s.Meta = reverseproxy.ServiceMeta{}
|
||||
s.Meta = rpservice.Meta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
}
|
||||
@@ -2142,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
s.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
s.Targets = []*reverseproxy.Target{}
|
||||
s.Targets = []*rpservice.Target{}
|
||||
return &s, nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -2154,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
|
||||
serviceIDs := make([]string, len(services))
|
||||
serviceMap := make(map[string]*reverseproxy.Service)
|
||||
serviceMap := make(map[string]*rpservice.Service)
|
||||
for i, s := range services {
|
||||
serviceIDs[i] = s.ID
|
||||
serviceMap[s.ID] = s
|
||||
@@ -2165,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
|
||||
var t reverseproxy.Target
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
|
||||
var t rpservice.Target
|
||||
var path sql.NullString
|
||||
err := row.Scan(
|
||||
&t.ID,
|
||||
@@ -4852,7 +4853,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
@@ -4866,16 +4867,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
}
|
||||
|
||||
// Create target type instance outside transaction to avoid variable shadowing
|
||||
targetType := &rpservice.Target{}
|
||||
|
||||
// Use a transaction to ensure atomic updates of the service and its targets
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete existing targets
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4896,7 +4900,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete service from store")
|
||||
@@ -4910,7 +4914,7 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete target from store")
|
||||
@@ -4924,7 +4928,7 @@ func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete targets from store")
|
||||
@@ -4934,8 +4938,8 @@ func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, s
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID retrieves all targets for a given service
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) {
|
||||
var targets []*reverseproxy.Target
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
|
||||
var targets []*rpservice.Target
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
@@ -4949,13 +4953,13 @@ func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength Locki
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var service *rpservice.Service
|
||||
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -4973,30 +4977,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
||||
}
|
||||
|
||||
for _, service := range serviceList {
|
||||
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
var service *reverseproxy.Service
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
var service *rpservice.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -5014,13 +4996,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -5036,13 +5018,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -5270,13 +5252,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
|
||||
return query
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var target *reverseproxy.Target
|
||||
var target *rpservice.Target
|
||||
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -5289,3 +5271,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// SaveProxy saves or updates a proxy in the database
|
||||
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Update("last_seen", time.Now())
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("last_seen < ?", cutoffTime).
|
||||
Delete(&proxy.Proxy{})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
|
||||
}
|
||||
|
||||
if result.RowsAffected > 0 {
|
||||
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user