From 0cb02bd906e4ea2724b01dff094f2837d0489289 Mon Sep 17 00:00:00 2001 From: pascal Date: Tue, 10 Feb 2026 17:12:34 +0100 Subject: [PATCH] fix path handling + extract targets to separate table + guard resource/peer deletion --- .../modules/reverseproxy/interface.go | 1 + .../modules/reverseproxy/manager/manager.go | 13 ++ .../modules/reverseproxy/reverseproxy.go | 40 +++-- management/internals/server/modules.go | 5 + management/server/account.go | 4 + management/server/account/manager.go | 2 + .../server/networks/resources/manager.go | 8 + management/server/peer.go | 8 + management/server/store/sql_store.go | 143 +++++++++++++++--- management/server/store/store.go | 1 + shared/management/status/error.go | 8 + 11 files changed, 197 insertions(+), 36 deletions(-) diff --git a/management/internals/modules/reverseproxy/interface.go b/management/internals/modules/reverseproxy/interface.go index 93ed19e0f..83064e2c6 100644 --- a/management/internals/modules/reverseproxy/interface.go +++ b/management/internals/modules/reverseproxy/interface.go @@ -17,4 +17,5 @@ type Manager interface { GetGlobalReverseProxies(ctx context.Context) ([]*ReverseProxy, error) GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*ReverseProxy, error) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*ReverseProxy, error) + GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) } diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 86930ae72..f504fb055 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -472,3 +472,16 @@ func (m *managerImpl) GetAccountReverseProxies(ctx context.Context, accountID st return proxies, nil } + +func (m *managerImpl) GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { + target, err := m.store.GetReverseProxyTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) + if err != nil { + return "", fmt.Errorf("failed to get reverse proxy target by resource ID: %w", err) + } + + if target == nil { + return "", nil + } + + return target.ReverseProxyID, nil +} diff --git a/management/internals/modules/reverseproxy/reverseproxy.go b/management/internals/modules/reverseproxy/reverseproxy.go index d2672e715..502de0d4a 100644 --- a/management/internals/modules/reverseproxy/reverseproxy.go +++ b/management/internals/modules/reverseproxy/reverseproxy.go @@ -42,13 +42,16 @@ const ( ) type Target struct { - Path *string `json:"path,omitempty"` - Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored - Port int `json:"port"` - Protocol string `json:"protocol"` - TargetId string `json:"target_id"` - TargetType string `json:"target_type"` - Enabled bool `json:"enabled"` + ID uint `gorm:"primaryKey" json:"-"` + AccountID string `gorm:"index:idx_target_account;not null" json:"-"` + ReverseProxyID string `gorm:"index:idx_reverse_proxy_targets;not null" json:"-"` + Path *string `json:"path,omitempty"` + Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored + Port int `gorm:"index:idx_target_port" json:"port"` + Protocol string `gorm:"index:idx_target_protocol" json:"protocol"` + TargetId string `gorm:"index:idx_target_id" json:"target_id"` + TargetType string `gorm:"index:idx_target_type" json:"target_type"` + Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"` } type PasswordAuthConfig struct { @@ -91,7 +94,7 @@ type ReverseProxy struct { Name string Domain string `gorm:"index"` ProxyCluster string `gorm:"index"` - Targets []*Target `gorm:"serializer:json"` + Targets []*Target `gorm:"foreignKey:ReverseProxyID;constraint:OnDelete:CASCADE"` Enabled bool PassHostHeader bool RewriteRedirects bool @@ -102,6 +105,10 @@ type ReverseProxy struct { } func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *ReverseProxy { + for _, target := range targets { + target.AccountID = accountID + } + rp := &ReverseProxy{ AccountID: accountID, Name: name, @@ -198,23 +205,22 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid continue } - path := "/" - if target.Path != nil { - path = *target.Path - } - // TODO: Make path prefix stripping configurable per-target. // Currently the matching prefix is baked into the target URL path, // so the proxy strips-then-re-adds it (effectively a no-op). targetURL := url.URL{ Scheme: target.Protocol, Host: target.Host, - Path: path, + Path: "/", // TODO: support service path } if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) { targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port)) } + path := "/" + if target.Path != nil { + path = *target.Path + } pathMappings = append(pathMappings, &proto.PathMapping{ Path: path, Target: targetURL.String(), @@ -279,6 +285,7 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st targets := make([]*Target, 0, len(req.Targets)) for _, apiTarget := range req.Targets { target := &Target{ + AccountID: accountID, Path: apiTarget.Path, Port: apiTarget.Port, Protocol: string(apiTarget.Protocol), @@ -369,7 +376,10 @@ func (r *ReverseProxy) EventMeta() map[string]any { func (r *ReverseProxy) Copy() *ReverseProxy { targets := make([]*Target, len(r.Targets)) - copy(targets, r.Targets) + for i, target := range r.Targets { + targetCopy := *target + targets[i] = &targetCopy + } return &ReverseProxy{ ID: r.ID, diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 1ff409289..f1360dbce 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -101,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager { if err != nil { log.Fatalf("failed to create account manager: %v", err) } + + s.AfterInit(func(s *BaseServer) { + accountManager.SetReverseProxyManager(s.ReverseProxyManager()) + }) + return accountManager }) } diff --git a/management/server/account.go b/management/server/account.go index a0f68452a..30b129856 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -115,6 +115,10 @@ type DefaultAccountManager struct { var _ account.Manager = (*DefaultAccountManager)(nil) +func (am *DefaultAccountManager) SetReverseProxyManager(reverseProxyManager reverseproxy.Manager) { + am.reverseProxyManager = reverseProxyManager +} + func isUniqueConstraintError(err error) bool { switch { case strings.Contains(err.Error(), "(SQLSTATE 23505)"), diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 1d25b0af7..ae92d679e 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -6,6 +6,7 @@ import ( "net/netip" "time" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/shared/auth" nbdns "github.com/netbirdio/netbird/dns" @@ -139,4 +140,5 @@ type Manager interface { CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) + SetReverseProxyManager(reverseProxyManager reverseproxy.Manager) } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 7497aa26e..decc61801 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -322,6 +322,14 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return status.NewPermissionDeniedError() } + proxyID, err := m.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, resourceID) + if err != nil { + return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err) + } + if proxyID != "" { + return status.NewResourceInUseError(resourceID, proxyID) + } + var events []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID) diff --git a/management/server/peer.go b/management/server/peer.go index ca9961b37..1297662e8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -489,6 +489,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var settings *types.Settings var eventsToStore []func() + proxyID, err := am.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, peerID) + if err != nil { + return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err) + } + if proxyID != "" { + return status.NewPeerInUseError(peerID, proxyID) + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index cf7067a7c..048cd8962 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -131,7 +131,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, - &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &domain.Domain{}, + &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &reverseproxy.Target{}, &domain.Domain{}, &accesslogs.AccessLogEntry{}, ) if err != nil { @@ -2064,45 +2064,51 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p } func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) { - const query = `SELECT id, account_id, name, domain, targets, enabled, auth, - meta_created_at, meta_certificate_issued_at, meta_status + const proxyQuery = `SELECT id, account_id, name, domain, enabled, auth, + meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, + pass_host_header, rewrite_redirects, session_private_key, session_public_key FROM reverse_proxies WHERE account_id = $1` - rows, err := s.pool.Query(ctx, query, accountID) + + const targetsQuery = `SELECT id, account_id, reverse_proxy_id, path, host, port, protocol, + target_id, target_type, enabled + FROM targets WHERE reverse_proxy_id = ANY($1)` + + proxyRows, err := s.pool.Query(ctx, proxyQuery, accountID) if err != nil { return nil, err } - proxies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) { + + proxies, err := pgx.CollectRows(proxyRows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) { var p reverseproxy.ReverseProxy var auth []byte - var targets []byte var createdAt, certIssuedAt sql.NullTime - var status sql.NullString + var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString err := row.Scan( &p.ID, &p.AccountID, &p.Name, &p.Domain, - &targets, &p.Enabled, &auth, &createdAt, &certIssuedAt, &status, + &proxyCluster, + &p.PassHostHeader, + &p.RewriteRedirects, + &sessionPrivateKey, + &sessionPublicKey, ) if err != nil { return nil, err } - // Unmarshal JSON fields + if auth != nil { if err := json.Unmarshal(auth, &p.Auth); err != nil { return nil, err } } - if targets != nil { - if err := json.Unmarshal(targets, &p.Targets); err != nil { - return nil, err - } - } + p.Meta = reverseproxy.ReverseProxyMeta{} if createdAt.Valid { p.Meta.CreatedAt = createdAt.Time @@ -2113,12 +2119,72 @@ func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverse if status.Valid { p.Meta.Status = status.String } + if proxyCluster.Valid { + p.ProxyCluster = proxyCluster.String + } + if sessionPrivateKey.Valid { + p.SessionPrivateKey = sessionPrivateKey.String + } + if sessionPublicKey.Valid { + p.SessionPublicKey = sessionPublicKey.String + } + p.Targets = []*reverseproxy.Target{} return &p, nil }) if err != nil { return nil, err } + + if len(proxies) == 0 { + return proxies, nil + } + + proxyIDs := make([]string, len(proxies)) + proxyMap := make(map[string]*reverseproxy.ReverseProxy) + for i, p := range proxies { + proxyIDs[i] = p.ID + proxyMap[p.ID] = p + } + + targetRows, err := s.pool.Query(ctx, targetsQuery, proxyIDs) + if err != nil { + return nil, err + } + + targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) { + var t reverseproxy.Target + var path sql.NullString + err := row.Scan( + &t.ID, + &t.AccountID, + &t.ReverseProxyID, + &path, + &t.Host, + &t.Port, + &t.Protocol, + &t.TargetId, + &t.TargetType, + &t.Enabled, + ) + if err != nil { + return nil, err + } + if path.Valid { + t.Path = &path.String + } + return &t, nil + }) + if err != nil { + return nil, err + } + + for _, target := range targets { + if proxy, ok := proxyMap[target.ReverseProxyID]; ok { + proxy.Targets = append(proxy.Targets, target) + } + } + return proxies, nil } @@ -4778,9 +4844,24 @@ func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.R if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt reverse proxy data: %w", err) } - result := s.db.Select("*").Save(proxyCopy) - if result.Error != nil { - log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error) + + // Use a transaction to ensure atomic updates of the proxy and its targets + err := s.db.Transaction(func(tx *gorm.DB) error { + // Delete existing targets + if err := tx.Where("reverse_proxy_id = ?", proxyCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil { + return err + } + + // Update the proxy and create new targets + if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(proxyCopy).Error; err != nil { + return err + } + + return nil + }) + + if err != nil { + log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", err) return status.Errorf(status.Internal, "failed to update reverse proxy to store") } @@ -4802,7 +4883,7 @@ func (s *SqlStore) DeleteReverseProxy(ctx context.Context, accountID, proxyID st } func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) { - tx := s.db + tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -4827,7 +4908,7 @@ func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength Locking func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) { var proxy *reverseproxy.ReverseProxy - result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy) + result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "reverse proxy with domain %s not found", domain) @@ -4845,7 +4926,7 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai } func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) { - tx := s.db + tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -4867,7 +4948,7 @@ func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingSt } func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) { - tx := s.db + tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } @@ -5001,3 +5082,23 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin return logs, nil } + +func (s *SqlStore) GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var target *reverseproxy.Target + result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "reverse proxy target with ID %s not found", targetID) + } + + log.WithContext(ctx).Errorf("failed to get reverse proxy target from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get reverse proxy target from store") + } + + return target, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index d2fd059e0..94e5a50b9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -267,6 +267,7 @@ type Store interface { CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*accesslogs.AccessLogEntry, error) + GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) } const ( diff --git a/shared/management/status/error.go b/shared/management/status/error.go index ea02173e9..78288aef3 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -262,3 +262,11 @@ func NewZoneNotFoundError(zoneID string) error { func NewDNSRecordNotFoundError(recordID string) error { return Errorf(NotFound, "dns record: %s not found", recordID) } + +func NewResourceInUseError(resourceID string, proxyID string) error { + return Errorf(PreconditionFailed, "resource %s is in use by proxy %s", resourceID, proxyID) +} + +func NewPeerInUseError(peerID string, proxyID string) error { + return Errorf(PreconditionFailed, "peer %s is in use by proxy %s", peerID, proxyID) +}