mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
fix path handling + extract targets to separate table + guard resource/peer deletion
This commit is contained in:
@@ -17,4 +17,5 @@ type Manager interface {
|
||||
GetGlobalReverseProxies(ctx context.Context) ([]*ReverseProxy, error)
|
||||
GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*ReverseProxy, error)
|
||||
GetAccountReverseProxies(ctx context.Context, accountID string) ([]*ReverseProxy, error)
|
||||
GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||
}
|
||||
|
||||
@@ -472,3 +472,16 @@ func (m *managerImpl) GetAccountReverseProxies(ctx context.Context, accountID st
|
||||
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetProxyIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetReverseProxyTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get reverse proxy target by resource ID: %w", err)
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return target.ReverseProxyID, nil
|
||||
}
|
||||
|
||||
@@ -42,13 +42,16 @@ const (
|
||||
)
|
||||
|
||||
type Target struct {
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `json:"port"`
|
||||
Protocol string `json:"protocol"`
|
||||
TargetId string `json:"target_id"`
|
||||
TargetType string `json:"target_type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ReverseProxyID string `gorm:"index:idx_reverse_proxy_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
@@ -91,7 +94,7 @@ type ReverseProxy struct {
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
ProxyCluster string `gorm:"index"`
|
||||
Targets []*Target `gorm:"serializer:json"`
|
||||
Targets []*Target `gorm:"foreignKey:ReverseProxyID;constraint:OnDelete:CASCADE"`
|
||||
Enabled bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
@@ -102,6 +105,10 @@ type ReverseProxy struct {
|
||||
}
|
||||
|
||||
func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *ReverseProxy {
|
||||
for _, target := range targets {
|
||||
target.AccountID = accountID
|
||||
}
|
||||
|
||||
rp := &ReverseProxy{
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
@@ -198,23 +205,22 @@ func (r *ReverseProxy) ToProtoMapping(operation Operation, authToken string, oid
|
||||
continue
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
// TODO: Make path prefix stripping configurable per-target.
|
||||
// Currently the matching prefix is baked into the target URL path,
|
||||
// so the proxy strips-then-re-adds it (effectively a no-op).
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Path: path,
|
||||
Path: "/", // TODO: support service path
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
pathMappings = append(pathMappings, &proto.PathMapping{
|
||||
Path: path,
|
||||
Target: targetURL.String(),
|
||||
@@ -279,6 +285,7 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
|
||||
targets := make([]*Target, 0, len(req.Targets))
|
||||
for _, apiTarget := range req.Targets {
|
||||
target := &Target{
|
||||
AccountID: accountID,
|
||||
Path: apiTarget.Path,
|
||||
Port: apiTarget.Port,
|
||||
Protocol: string(apiTarget.Protocol),
|
||||
@@ -369,7 +376,10 @@ func (r *ReverseProxy) EventMeta() map[string]any {
|
||||
|
||||
func (r *ReverseProxy) Copy() *ReverseProxy {
|
||||
targets := make([]*Target, len(r.Targets))
|
||||
copy(targets, r.Targets)
|
||||
for i, target := range r.Targets {
|
||||
targetCopy := *target
|
||||
targets[i] = &targetCopy
|
||||
}
|
||||
|
||||
return &ReverseProxy{
|
||||
ID: r.ID,
|
||||
|
||||
@@ -101,6 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
accountManager.SetReverseProxyManager(s.ReverseProxyManager())
|
||||
})
|
||||
|
||||
return accountManager
|
||||
})
|
||||
}
|
||||
|
||||
@@ -115,6 +115,10 @@ type DefaultAccountManager struct {
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func (am *DefaultAccountManager) SetReverseProxyManager(reverseProxyManager reverseproxy.Manager) {
|
||||
am.reverseProxyManager = reverseProxyManager
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "(SQLSTATE 23505)"),
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -139,4 +140,5 @@ type Manager interface {
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
SetReverseProxyManager(reverseProxyManager reverseproxy.Manager)
|
||||
}
|
||||
|
||||
@@ -322,6 +322,14 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
proxyID, err := m.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err)
|
||||
}
|
||||
if proxyID != "" {
|
||||
return status.NewResourceInUseError(resourceID, proxyID)
|
||||
}
|
||||
|
||||
var events []func()
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||
|
||||
@@ -489,6 +489,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
|
||||
proxyID, err := am.reverseProxyManager.GetProxyIDByTargetID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by reverse proxy: %w", err)
|
||||
}
|
||||
if proxyID != "" {
|
||||
return status.NewPeerInUseError(peerID, proxyID)
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
|
||||
@@ -131,7 +131,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &domain.Domain{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.ReverseProxy{}, &reverseproxy.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{},
|
||||
)
|
||||
if err != nil {
|
||||
@@ -2064,45 +2064,51 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
}
|
||||
|
||||
func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
|
||||
const query = `SELECT id, account_id, name, domain, targets, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status
|
||||
const proxyQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
FROM reverse_proxies WHERE account_id = $1`
|
||||
rows, err := s.pool.Query(ctx, query, accountID)
|
||||
|
||||
const targetsQuery = `SELECT id, account_id, reverse_proxy_id, path, host, port, protocol,
|
||||
target_id, target_type, enabled
|
||||
FROM targets WHERE reverse_proxy_id = ANY($1)`
|
||||
|
||||
proxyRows, err := s.pool.Query(ctx, proxyQuery, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) {
|
||||
|
||||
proxies, err := pgx.CollectRows(proxyRows, func(row pgx.CollectableRow) (*reverseproxy.ReverseProxy, error) {
|
||||
var p reverseproxy.ReverseProxy
|
||||
var auth []byte
|
||||
var targets []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status sql.NullString
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
err := row.Scan(
|
||||
&p.ID,
|
||||
&p.AccountID,
|
||||
&p.Name,
|
||||
&p.Domain,
|
||||
&targets,
|
||||
&p.Enabled,
|
||||
&auth,
|
||||
&createdAt,
|
||||
&certIssuedAt,
|
||||
&status,
|
||||
&proxyCluster,
|
||||
&p.PassHostHeader,
|
||||
&p.RewriteRedirects,
|
||||
&sessionPrivateKey,
|
||||
&sessionPublicKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Unmarshal JSON fields
|
||||
|
||||
if auth != nil {
|
||||
if err := json.Unmarshal(auth, &p.Auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if targets != nil {
|
||||
if err := json.Unmarshal(targets, &p.Targets); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
p.Meta = reverseproxy.ReverseProxyMeta{}
|
||||
if createdAt.Valid {
|
||||
p.Meta.CreatedAt = createdAt.Time
|
||||
@@ -2113,12 +2119,72 @@ func (s *SqlStore) getProxies(ctx context.Context, accountID string) ([]*reverse
|
||||
if status.Valid {
|
||||
p.Meta.Status = status.String
|
||||
}
|
||||
if proxyCluster.Valid {
|
||||
p.ProxyCluster = proxyCluster.String
|
||||
}
|
||||
if sessionPrivateKey.Valid {
|
||||
p.SessionPrivateKey = sessionPrivateKey.String
|
||||
}
|
||||
if sessionPublicKey.Valid {
|
||||
p.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
p.Targets = []*reverseproxy.Target{}
|
||||
return &p, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(proxies) == 0 {
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
proxyIDs := make([]string, len(proxies))
|
||||
proxyMap := make(map[string]*reverseproxy.ReverseProxy)
|
||||
for i, p := range proxies {
|
||||
proxyIDs[i] = p.ID
|
||||
proxyMap[p.ID] = p
|
||||
}
|
||||
|
||||
targetRows, err := s.pool.Query(ctx, targetsQuery, proxyIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
|
||||
var t reverseproxy.Target
|
||||
var path sql.NullString
|
||||
err := row.Scan(
|
||||
&t.ID,
|
||||
&t.AccountID,
|
||||
&t.ReverseProxyID,
|
||||
&path,
|
||||
&t.Host,
|
||||
&t.Port,
|
||||
&t.Protocol,
|
||||
&t.TargetId,
|
||||
&t.TargetType,
|
||||
&t.Enabled,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if path.Valid {
|
||||
t.Path = &path.String
|
||||
}
|
||||
return &t, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
if proxy, ok := proxyMap[target.ReverseProxyID]; ok {
|
||||
proxy.Targets = append(proxy.Targets, target)
|
||||
}
|
||||
}
|
||||
|
||||
return proxies, nil
|
||||
}
|
||||
|
||||
@@ -4778,9 +4844,24 @@ func (s *SqlStore) UpdateReverseProxy(ctx context.Context, proxy *reverseproxy.R
|
||||
if err := proxyCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt reverse proxy data: %w", err)
|
||||
}
|
||||
result := s.db.Select("*").Save(proxyCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", result.Error)
|
||||
|
||||
// Use a transaction to ensure atomic updates of the proxy and its targets
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete existing targets
|
||||
if err := tx.Where("reverse_proxy_id = ?", proxyCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the proxy and create new targets
|
||||
if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(proxyCopy).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update reverse proxy to store: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to update reverse proxy to store")
|
||||
}
|
||||
|
||||
@@ -4802,7 +4883,7 @@ func (s *SqlStore) DeleteReverseProxy(ctx context.Context, accountID, proxyID st
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength LockingStrength, accountID, proxyID string) (*reverseproxy.ReverseProxy, error) {
|
||||
tx := s.db
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -4827,7 +4908,7 @@ func (s *SqlStore) GetReverseProxyByID(ctx context.Context, lockStrength Locking
|
||||
|
||||
func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.ReverseProxy, error) {
|
||||
var proxy *reverseproxy.ReverseProxy
|
||||
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy)
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&proxy)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "reverse proxy with domain %s not found", domain)
|
||||
@@ -4845,7 +4926,7 @@ func (s *SqlStore) GetReverseProxyByDomain(ctx context.Context, accountID, domai
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.ReverseProxy, error) {
|
||||
tx := s.db
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -4867,7 +4948,7 @@ func (s *SqlStore) GetReverseProxies(ctx context.Context, lockStrength LockingSt
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountReverseProxies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error) {
|
||||
tx := s.db
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -5001,3 +5082,23 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
|
||||
return logs, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var target *reverseproxy.Target
|
||||
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "reverse proxy target with ID %s not found", targetID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Errorf("failed to get reverse proxy target from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get reverse proxy target from store")
|
||||
}
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
@@ -267,6 +267,7 @@ type Store interface {
|
||||
|
||||
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*accesslogs.AccessLogEntry, error)
|
||||
GetReverseProxyTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -262,3 +262,11 @@ func NewZoneNotFoundError(zoneID string) error {
|
||||
func NewDNSRecordNotFoundError(recordID string) error {
|
||||
return Errorf(NotFound, "dns record: %s not found", recordID)
|
||||
}
|
||||
|
||||
func NewResourceInUseError(resourceID string, proxyID string) error {
|
||||
return Errorf(PreconditionFailed, "resource %s is in use by proxy %s", resourceID, proxyID)
|
||||
}
|
||||
|
||||
func NewPeerInUseError(peerID string, proxyID string) error {
|
||||
return Errorf(PreconditionFailed, "peer %s is in use by proxy %s", peerID, proxyID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user