fix path handling + extract targets to separate table + guard resource/peer deletion

This commit is contained in:
pascal
2026-02-10 17:12:34 +01:00
parent 08d3867f41
commit 0cb02bd906
11 changed files with 197 additions and 36 deletions

View File

@@ -17,4 +17,5 @@ type Manager interface {
GetGlobalReverseProxies(ctx context.Context) ([]*ReverseProxy, error)
GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, accountID string) ([]*ReverseProxy, error)
GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}

View File

@@ -472,3 +472,16 @@ func (m *managerImpl) GetAccountReverseProxies(ctx context.Context, accountID st
return proxies, nil
}
func (m *managerImpl) GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetReverseProxyTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return "", fmt.Errorf("failed to get reverse proxy target by resource ID: %w", err)
}
if target == nil {
return "", nil
}
return target.ReverseProxyID, nil
}

View File

@@ -42,13 +42,16 @@ const (
)
type Target struct {
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `json:"port"`
Protocol string `json:"protocol"`
TargetId string `json:"target_id"`
TargetType string `json:"target_type"`
Enabled bool `json:"enabled"`
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ReverseProxyID string `gorm:"index:idx_reverse_proxy_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
@@ -91,7 +94,7 @@ type ReverseProxy struct {
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"serializer:json"`
Targets []*Target `gorm:"foreignKey:ReverseProxyID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
@@ -102,6 +105,10 @@ type ReverseProxy struct {
}
func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *ReverseProxy {
for _, target := range targets {
target.AccountID = accountID
}
rp := &ReverseProxy{
AccountID: accountID,
Name: name,
@@ -198,23 +205,22 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid
continue
}
path := "/"
if target.Path != nil {
path = *target.Path
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: path,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
@@ -279,6 +285,7 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
@@ -369,7 +376,10 @@ func (r *ReverseProxy) EventMeta() map[string]any {
func (r *ReverseProxy) Copy() *ReverseProxy {
targets := make([]*Target, len(r.Targets))
copy(targets, r.Targets)
for i, target := range r.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &ReverseProxy{
ID: r.ID,

View File

@@ -101,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetReverseProxyManager(s.ReverseProxyManager())
})
return accountManager
})
}

View File

@@ -115,6 +115,10 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetReverseProxyManager(reverseProxyManager reverseproxy.Manager) {
am.reverseProxyManager = reverseProxyManager
}
func isUniqueConstraintError(err error) bool {
switch {
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -139,4 +140,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetReverseProxyManager(reverseProxyManager reverseproxy.Manager)
}

View File

@@ -322,6 +322,14 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
proxyID, err := m.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err)
}
if proxyID != "" {
return status.NewResourceInUseError(resourceID, proxyID)
}
var events []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)

View File

@@ -489,6 +489,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings
var eventsToStore []func()
proxyID, err := am.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err)
}
if proxyID != "" {
return status.NewPeerInUseError(peerID, proxyID)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {

View File

@@ -131,7 +131,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &domain.Domain{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &reverseproxy.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{},
)
if err != nil {
@@ -2064,45 +2064,51 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
}
func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
const query = `SELECT id, account_id, name, domain, targets, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status
const proxyQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
FROM reverse_proxies WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
const targetsQuery = `SELECT id, account_id, reverse_proxy_id, path, host, port, protocol,
target_id, target_type, enabled
FROM targets WHERE reverse_proxy_id = ANY($1)`
proxyRows, err := s.pool.Query(ctx, proxyQuery, accountID)
if err != nil {
return nil, err
}
proxies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) {
proxies, err := pgx.CollectRows(proxyRows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) {
var p reverseproxy.ReverseProxy
var auth []byte
var targets []byte
var createdAt, certIssuedAt sql.NullTime
var status sql.NullString
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
err := row.Scan(
&p.ID,
&p.AccountID,
&p.Name,
&p.Domain,
&targets,
&p.Enabled,
&auth,
&createdAt,
&certIssuedAt,
&status,
&proxyCluster,
&p.PassHostHeader,
&p.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
)
if err != nil {
return nil, err
}
// Unmarshal JSON fields
if auth != nil {
if err := json.Unmarshal(auth, &p.Auth); err != nil {
return nil, err
}
}
if targets != nil {
if err := json.Unmarshal(targets, &p.Targets); err != nil {
return nil, err
}
}
p.Meta = reverseproxy.ReverseProxyMeta{}
if createdAt.Valid {
p.Meta.CreatedAt = createdAt.Time
@@ -2113,12 +2119,72 @@ func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverse
if status.Valid {
p.Meta.Status = status.String
}
if proxyCluster.Valid {
p.ProxyCluster = proxyCluster.String
}
if sessionPrivateKey.Valid {
p.SessionPrivateKey = sessionPrivateKey.String
}
if sessionPublicKey.Valid {
p.SessionPublicKey = sessionPublicKey.String
}
p.Targets = []*reverseproxy.Target{}
return &p, nil
})
if err != nil {
return nil, err
}
if len(proxies) == 0 {
return proxies, nil
}
proxyIDs := make([]string, len(proxies))
proxyMap := make(map[string]*reverseproxy.ReverseProxy)
for i, p := range proxies {
proxyIDs[i] = p.ID
proxyMap[p.ID] = p
}
targetRows, err := s.pool.Query(ctx, targetsQuery, proxyIDs)
if err != nil {
return nil, err
}
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
var t reverseproxy.Target
var path sql.NullString
err := row.Scan(
&t.ID,
&t.AccountID,
&t.ReverseProxyID,
&path,
&t.Host,
&t.Port,
&t.Protocol,
&t.TargetId,
&t.TargetType,
&t.Enabled,
)
if err != nil {
return nil, err
}
if path.Valid {
t.Path = &path.String
}
return &t, nil
})
if err != nil {
return nil, err
}
for _, target := range targets {
if proxy, ok := proxyMap[target.ReverseProxyID]; ok {
proxy.Targets = append(proxy.Targets, target)
}
}
return proxies, nil
}
@@ -4778,9 +4844,24 @@ func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.R
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt reverse proxy data: %w", err)
}
result := s.db.Select("*").Save(proxyCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error)
// Use a transaction to ensure atomic updates of the proxy and its targets
err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets
if err := tx.Where("reverse_proxy_id = ?", proxyCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
return err
}
// Update the proxy and create new targets
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(proxyCopy).Error; err != nil {
return err
}
return nil
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", err)
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
}
@@ -4802,7 +4883,7 @@ func (s *SqlStore) DeleteReverseProxy(ctx context.Context, accountID, proxyID st
}
func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
tx := s.db
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4827,7 +4908,7 @@ func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength Locking
func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) {
var proxy *reverseproxy.ReverseProxy
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy)
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "reverse proxy with domain %s not found", domain)
@@ -4845,7 +4926,7 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai
}
func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4867,7 +4948,7 @@ func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) {
tx := s.db
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5001,3 +5082,23 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
return logs, nil
}
func (s *SqlStore) GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var target *reverseproxy.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "reverse proxy target with ID %s not found", targetID)
}
log.WithContext(ctx).Errorf("failed to get reverse proxy target from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get reverse proxy target from store")
}
return target, nil
}

View File

@@ -267,6 +267,7 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*accesslogs.AccessLogEntry, error)
GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
}
const (

View File

@@ -262,3 +262,11 @@ func NewZoneNotFoundError(zoneID string) error {
func NewDNSRecordNotFoundError(recordID string) error {
return Errorf(NotFound, "dns record: %s not found", recordID)
}
func NewResourceInUseError(resourceID string, proxyID string) error {
return Errorf(PreconditionFailed, "resource %s is in use by proxy %s", resourceID, proxyID)
}
func NewPeerInUseError(peerID string, proxyID string) error {
return Errorf(PreconditionFailed, "peer %s is in use by proxy %s", peerID, proxyID)
}