send updates on changes

This commit is contained in:
pascal
2026-02-09 17:06:04 +01:00
parent 73aa0785ba
commit 9a67a8e427
10 changed files with 214 additions and 62 deletions

View File

@@ -12,4 +12,9 @@ type Manager interface {
DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error DeleteReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error
SetStatus(ctx context.Context, accountID, reverseProxyID string, status ProxyStatus) error SetStatus(ctx context.Context, accountID, reverseProxyID string, status ProxyStatus) error
ReloadAllReverseProxiesForAccount(ctx context.Context, accountID string) error
ReloadReverseProxy(ctx context.Context, accountID, reverseProxyID string) error
GetGlobalReverseProxies(ctx context.Context) ([]*ReverseProxy, error)
GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, accountID string) ([]*ReverseProxy, error)
} }

View File

@@ -54,7 +54,43 @@ func (m *managerImpl) GetAllReverseProxies(ctx context.Context, accountID, userI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID) proxies, err := m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxies: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, accountID, proxy)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
}
return proxies, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, proxy *reverseproxy.ReverseProxy) error {
for _, target := range proxy.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
return fmt.Errorf("failed to get peer by id: %w", err)
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
return fmt.Errorf("failed to get resource by id: %w", err)
}
target.Host = resource.Address
case reverseproxy.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
} }
func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) { func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
@@ -66,7 +102,16 @@ func (m *managerImpl) GetReverseProxy(ctx context.Context, accountID, userID, re
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID) proxy, err := m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxy: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, proxy)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
return proxy, nil
} }
func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) { func (m *managerImpl) CreateReverseProxy(ctx context.Context, accountID, userID string, reverseProxy *reverseproxy.ReverseProxy) (*reverseproxy.ReverseProxy, error) {
@@ -149,6 +194,7 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
var oldCluster string var oldCluster string
var domainChanged bool var domainChanged bool
var reverseProxyEnabledChanged bool
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID) existingReverseProxy, err := transaction.GetReverseProxyByID(ctx, store.LockingStrengthUpdate, accountID, reverseProxy.ID)
@@ -184,6 +230,7 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
reverseProxy.Meta = existingReverseProxy.Meta reverseProxy.Meta = existingReverseProxy.Meta
reverseProxy.SessionPrivateKey = existingReverseProxy.SessionPrivateKey reverseProxy.SessionPrivateKey = existingReverseProxy.SessionPrivateKey
reverseProxy.SessionPublicKey = existingReverseProxy.SessionPublicKey reverseProxy.SessionPublicKey = existingReverseProxy.SessionPublicKey
reverseProxyEnabledChanged = existingReverseProxy.Enabled != reverseProxy.Enabled
if err = validateTargetReferences(ctx, transaction, accountID, reverseProxy.Targets); err != nil { if err = validateTargetReferences(ctx, transaction, accountID, reverseProxy.Targets); err != nil {
return err return err
@@ -201,12 +248,19 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta()) m.accountManager.StoreEvent(ctx, userID, reverseProxy.ID, accountID, activity.ReverseProxyUpdated, reverseProxy.EventMeta())
token, err := m.tokenStore.GenerateToken(accountID, reverseProxy.ID, 5*time.Minute)
if err != nil {
return nil, fmt.Errorf("failed to generate authentication token: %w", err)
}
switch { switch {
case domainChanged && oldCluster != reverseProxy.ProxyCluster: case domainChanged && oldCluster != reverseProxy.ProxyCluster:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), oldCluster) m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), oldCluster)
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster) m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
case !reverseProxy.Enabled: case !reverseProxy.Enabled && reverseProxyEnabledChanged:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster) m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
case reverseProxy.Enabled && reverseProxyEnabledChanged:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Create, token, m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
default: default:
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster) m.proxyGRPCServer.SendReverseProxyUpdateToCluster(reverseProxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), reverseProxy.ProxyCluster)
@@ -217,7 +271,7 @@ func (m *managerImpl) UpdateReverseProxy(ctx context.Context, accountID, userID
} }
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account. // validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []reverseproxy.Target) error { func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
for _, target := range targets { for _, target := range targets {
switch target.TargetType { switch target.TargetType {
case reverseproxy.TargetTypePeer: case reverseproxy.TargetTypePeer:
@@ -311,3 +365,86 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, reverseProxyID s
return nil return nil
}) })
} }
func (m *managerImpl) ReloadReverseProxy(ctx context.Context, accountID, reverseProxyID string) error {
proxy, err := m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
if err != nil {
return fmt.Errorf("failed to get reverse proxy: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, proxy)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(proxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), proxy.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllReverseProxiesForAccount(ctx context.Context, accountID string) error {
proxies, err := m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get reverse proxies: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, accountID, proxy)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
m.proxyGRPCServer.SendReverseProxyUpdateToCluster(proxy.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), proxy.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) GetGlobalReverseProxies(ctx context.Context) ([]*reverseproxy.ReverseProxy, error) {
proxies, err := m.store.GetReverseProxies(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxies: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, proxy.AccountID, proxy)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
}
return proxies, nil
}
func (m *managerImpl) GetProxyByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.ReverseProxy, error) {
proxy, err := m.store.GetReverseProxyByID(ctx, store.LockingStrengthNone, accountID, reverseProxyID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxy: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, proxy)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
return proxy, nil
}
func (m *managerImpl) GetAccountReverseProxies(ctx context.Context, accountID string) ([]*reverseproxy.ReverseProxy, error) {
proxies, err := m.store.GetAccountReverseProxies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get reverse proxies: %w", err)
}
for _, proxy := range proxies {
err = m.replaceHostByLookup(ctx, accountID, proxy)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for proxy %s: %w", proxy.ID, err)
}
}
return proxies, nil
}

View File

@@ -43,7 +43,7 @@ const (
type Target struct { type Target struct {
Path *string `json:"path,omitempty"` Path *string `json:"path,omitempty"`
Host string `json:"host"` Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `json:"port"` Port int `json:"port"`
Protocol string `json:"protocol"` Protocol string `json:"protocol"`
TargetId string `json:"target_id"` TargetId string `json:"target_id"`
@@ -92,9 +92,9 @@ type ReverseProxy struct {
ID string `gorm:"primaryKey"` ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"` AccountID string `gorm:"index"`
Name string Name string
Domain string `gorm:"index"` Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"` ProxyCluster string `gorm:"index"`
Targets []Target `gorm:"serializer:json"` Targets []*Target `gorm:"serializer:json"`
Enabled bool Enabled bool
PassHostHeader bool PassHostHeader bool
RewriteRedirects bool RewriteRedirects bool
@@ -104,7 +104,7 @@ type ReverseProxy struct {
SessionPublicKey string `gorm:"column:session_public_key"` SessionPublicKey string `gorm:"column:session_public_key"`
} }
func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []Target, enabled bool) *ReverseProxy { func NewReverseProxy(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *ReverseProxy {
rp := &ReverseProxy{ rp := &ReverseProxy{
AccountID: accountID, AccountID: accountID,
Name: name, Name: name,
@@ -157,7 +157,7 @@ func (r *ReverseProxy) ToAPIResponse() *api.ReverseProxy {
for _, target := range r.Targets { for _, target := range r.Targets {
apiTargets = append(apiTargets, api.ReverseProxyTarget{ apiTargets = append(apiTargets, api.ReverseProxyTarget{
Path: target.Path, Path: target.Path,
Host: target.Host, Host: &target.Host,
Port: target.Port, Port: target.Port,
Protocol: api.ReverseProxyTargetProtocol(target.Protocol), Protocol: api.ReverseProxyTargetProtocol(target.Protocol),
TargetId: target.TargetId, TargetId: target.TargetId,
@@ -280,19 +280,22 @@ func (r *ReverseProxy) FromAPIRequest(req *api.ReverseProxyRequest, accountID st
r.Domain = req.Domain r.Domain = req.Domain
r.AccountID = accountID r.AccountID = accountID
targets := make([]Target, 0, len(req.Targets)) targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets { for _, apiTarget := range req.Targets {
accessLocal := apiTarget.AccessLocal != nil && *apiTarget.AccessLocal accessLocal := apiTarget.AccessLocal != nil && *apiTarget.AccessLocal
targets = append(targets, Target{ target := &Target{
Path: apiTarget.Path, Path: apiTarget.Path,
Host: apiTarget.Host,
Port: apiTarget.Port, Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol), Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId, TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType), TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled, Enabled: apiTarget.Enabled,
AccessLocal: accessLocal, AccessLocal: accessLocal,
}) }
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
} }
r.Targets = targets r.Targets = targets
@@ -349,8 +352,14 @@ func (r *ReverseProxy) Validate() error {
for i, target := range r.Targets { for i, target := range r.Targets {
switch target.TargetType { switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeSubnet, TargetTypeDomain: case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// valid resource types if target.Host != "" {
return fmt.Errorf("target %d has host specified but target_type is %q", i, target.TargetType)
}
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default: default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType) return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
} }
@@ -367,7 +376,7 @@ func (r *ReverseProxy) EventMeta() map[string]any {
} }
func (r *ReverseProxy) Copy() *ReverseProxy { func (r *ReverseProxy) Copy() *ReverseProxy {
targets := make([]Target, len(r.Targets)) targets := make([]*Target, len(r.Targets))
copy(targets, r.Targets) copy(targets, r.Targets)
return &ReverseProxy{ return &ReverseProxy{

View File

@@ -163,7 +163,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.Store(), s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager()) proxyService.SetProxyManager(s.ReverseProxyManager())
}) })

View File

@@ -157,7 +157,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager { func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager { return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager()) return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
}) })
} }

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
proxyauth "github.com/netbirdio/netbird/proxy/auth" proxyauth "github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -42,17 +41,6 @@ type ProxyOIDCConfig struct {
KeysLocation string KeysLocation string
} }
type reverseProxyStore interface {
GetReverseProxies(ctx context.Context, lockStrength store.LockingStrength) ([]*reverseproxy.ReverseProxy, error)
GetAccountReverseProxies(ctx context.Context, lockStrength store.LockingStrength, accountID string) ([]*reverseproxy.ReverseProxy, error)
GetReverseProxyByID(ctx context.Context, lockStrength store.LockingStrength, accountID string, serviceID string) (*reverseproxy.ReverseProxy, error)
}
type reverseProxyManager interface {
SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error
SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error
}
// ClusterInfo contains information about a proxy cluster. // ClusterInfo contains information about a proxy cluster.
type ClusterInfo struct { type ClusterInfo struct {
Address string Address string
@@ -72,14 +60,11 @@ type ProxyServiceServer struct {
// Channel for broadcasting reverse proxy updates to all proxies // Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping updatesChan chan *proto.ProxyMapping
// Store of reverse proxies
reverseProxyStore reverseProxyStore
// Manager for access logs // Manager for access logs
accessLogManager accesslogs.Manager accessLogManager accesslogs.Manager
// Manager for reverse proxy operations // Manager for reverse proxy operations
reverseProxyManager reverseProxyManager reverseProxyManager reverseproxy.Manager
// Manager for peers // Manager for peers
peersManager peers.Manager peersManager peers.Manager
@@ -106,18 +91,17 @@ type proxyConnection struct {
} }
// NewProxyServiceServer creates a new proxy service server // NewProxyServiceServer creates a new proxy service server
func NewProxyServiceServer(store reverseProxyStore, accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager) *ProxyServiceServer {
return &ProxyServiceServer{ return &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100), updatesChan: make(chan *proto.ProxyMapping, 100),
reverseProxyStore: store, accessLogManager: accessLogMgr,
accessLogManager: accessLogMgr, oidcConfig: oidcConfig,
oidcConfig: oidcConfig, tokenStore: tokenStore,
tokenStore: tokenStore, peersManager: peersManager,
peersManager: peersManager,
} }
} }
func (s *ProxyServiceServer) SetProxyManager(manager reverseProxyManager) { func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
s.reverseProxyManager = manager s.reverseProxyManager = manager
} }
@@ -189,7 +173,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
// sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy. // sendSnapshot sends the initial snapshot of reverse proxies to the connecting proxy.
// Only reverse proxies matching the proxy's cluster address are sent. // Only reverse proxies matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
reverseProxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) reverseProxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
if err != nil { if err != nil {
return fmt.Errorf("get reverse proxies from store: %w", err) return fmt.Errorf("get reverse proxies from store: %w", err)
} }
@@ -434,7 +418,7 @@ func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
} }
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
proxy, err := s.reverseProxyStore.GetReverseProxyByID(ctx, store.LockingStrengthNone, req.GetAccountId(), req.GetId()) proxy, err := s.reverseProxyManager.GetProxyByID(ctx, req.GetAccountId(), req.GetId())
if err != nil { if err != nil {
// TODO: log the error // TODO: log the error
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
@@ -613,7 +597,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err) return nil, status.Errorf(codes.InvalidArgument, "failed to parse redirect url: %v", err)
} }
// Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection. // Validate redirectURL against known proxy endpoints to avoid abuse of OIDC redirection.
proxies, err := s.reverseProxyStore.GetAccountReverseProxies(ctx, store.LockingStrengthNone, req.GetAccountId()) proxies, err := s.reverseProxyManager.GetAccountReverseProxies(ctx, req.GetAccountId())
if err != nil { if err != nil {
// TODO: log // TODO: log
return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "failed to get reverse proxy from store: %v", err)
@@ -719,7 +703,7 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user. // GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the proxy by domain to get its signing key // Find the proxy by domain to get its signing key
proxies, err := s.reverseProxyStore.GetReverseProxies(ctx, store.LockingStrengthNone) proxies, err := s.reverseProxyManager.GetGlobalReverseProxies(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("get reverse proxies: %w", err) return "", fmt.Errorf("get reverse proxies: %w", err)
} }

View File

@@ -15,6 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -82,8 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
settingsManager settings.Manager settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
// config contains the management server configuration // config contains the management server configuration
config *nbconfig.Config config *nbconfig.Config
@@ -321,6 +323,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err return err
} }
if err = am.reverseProxyManager.ReloadAllReverseProxiesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all reverse proxy for account %s: %v", accountID, err)
}
updateAccountPeers = true updateAccountPeers = true
} }

View File

@@ -5,6 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -30,21 +33,23 @@ type Manager interface {
} }
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager permissionsManager permissions.Manager
groupsManager groups.Manager groupsManager groups.Manager
accountManager account.Manager accountManager account.Manager
reverseProxyManager reverseproxy.Manager
} }
type mockManager struct { type mockManager struct {
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
groupsManager: groupsManager, groupsManager: groupsManager,
accountManager: accountManager, accountManager: accountManager,
reverseProxyManager: reverseproxyManager,
} }
} }
@@ -257,6 +262,14 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
event() event()
} }
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() {
err := m.reverseProxyManager.ReloadAllReverseProxiesForAccount(ctx, resource.AccountID)
if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
}
}()
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil return resource, nil

View File

@@ -2972,7 +2972,6 @@ components:
- target_id - target_id
- target_type - target_type
- protocol - protocol
- host
- port - port
- enabled - enabled
ReverseProxyAuthConfig: ReverseProxyAuthConfig:

View File

@@ -2084,12 +2084,12 @@ type ReverseProxyTarget struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Host Backend ip or domain for this target // Host Backend ip or domain for this target
Host string `json:"host"` Host *string `json:"host,omitempty"`
// Path URL path prefix for this target // Path URL path prefix for this target
Path *string `json:"path,omitempty"` Path *string `json:"path,omitempty"`
// Port Backend port for this target // Port Backend port for this target. Use 0 or omit to use the scheme default (80 for http, 443 for https).
Port int `json:"port"` Port int `json:"port"`
// Protocol Protocol to use when connecting to the backend // Protocol Protocol to use when connecting to the backend