Compare commits

...

12 Commits

Author SHA1 Message Date
mlsmaycon
dc136c17e9 Merge branch 'main' into feature/add-serial-to-proxy 2026-02-24 18:09:14 +01:00
pascal
b450fa2cca fix tests 2026-02-24 10:32:12 +01:00
pascal
33cda4d10c Merge branch 'main' into feature/add-serial-to-proxy 2026-02-23 17:17:06 +01:00
pascal
6df57623dd Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy 2026-02-23 16:54:22 +01:00
pascal
740c726a78 extract proxy controller 2026-02-23 15:28:20 +01:00
pascal
3af287ebab move service manager 2026-02-20 01:21:05 +01:00
pascal
d4d885d434 Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy 2026-02-20 00:35:10 +01:00
pascal
d212332f5d store proxies in DB 2026-02-20 00:28:45 +01:00
pascal
0e11258e97 fix test 2026-02-19 13:30:26 +01:00
pascal
31ecf8f1f5 allow redis for token store 2026-02-19 13:28:36 +01:00
pascal
e2df1fb35e add accountID when sending update to cluster 2026-02-19 12:02:31 +01:00
pascal
942cd5dc72 export methods 2026-02-19 10:50:32 +01:00
44 changed files with 1352 additions and 846 deletions

View File

@@ -27,21 +27,21 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}
type Manager struct {
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
proxyManager proxyManager
permissionsManager permissions.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
@@ -221,25 +228,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}
}
// GetClusterDomains returns a list of proxy cluster domains.
func (m Manager) GetClusterDomains() []string {
return m.proxyURLAllowList()
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
if m.proxyManager == nil {
return nil
}
return reverseProxyAddresses
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
if err != nil {
return nil
}
return addresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}

View File

@@ -0,0 +1,15 @@
package proxy
import (
"context"
"time"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}

View File

@@ -0,0 +1,106 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
}
// NewManager creates a new proxy Manager
func NewManager(store store) Manager {
return Manager{
store: store,
}
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,9 +1,11 @@
package reverseproxy
package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Manager interface {
@@ -14,7 +16,7 @@ type Manager interface {
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
@@ -26,3 +28,12 @@ type Manager interface {
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StartExposeReaper(ctx context.Context)
}
// ProxyController is responsible for managing proxy clusters and routing service updates.
type ProxyController interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -1,14 +1,15 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
// Package service is a generated GoMock package.
package service
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
@@ -239,7 +240,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
@@ -292,3 +293,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}
// MockProxyController is a mock of ProxyController interface.
type MockProxyController struct {
ctrl *gomock.Controller
recorder *MockProxyControllerMockRecorder
}
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
type MockProxyControllerMockRecorder struct {
mock *MockProxyController
}
// NewMockProxyController creates a new mock instance.
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
mock := &MockProxyController{ctrl: ctrl}
mock.recorder = &MockProxyControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockProxyController) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
)
type handler struct {
manager reverseproxy.Manager
manager rpservice.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCProxyController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCProxyController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
}
// NewGRPCProxyController creates a new GRPCProxyController.
func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
return &GRPCProxyController{
proxyGRPCServer: proxyGRPCServer,
}
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -27,7 +27,7 @@ type trackedExpose struct {
type exposeTracker struct {
activeExposes sync.Map
exposeCreateMu sync.Mutex
manager *managerImpl
manager *Manager
}
func exposeKey(peerID, domain string) string {

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
)
func TestExposeKey(t *testing.T) {
@@ -120,7 +120,7 @@ func TestReapExpiredExposes(t *testing.T) {
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -156,7 +156,7 @@ func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -191,7 +191,7 @@ func TestConcurrentTrackAndCount(t *testing.T) {
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})

View File

@@ -4,24 +4,21 @@ import (
"context"
"fmt"
"math/rand/v2"
"slices"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"slices"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -33,24 +30,22 @@ type ClusterDeriver interface {
GetClusterDomains() []string
}
type managerImpl struct {
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
settingsManager settings.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
proxyController service.ProxyController
clusterDeriver ClusterDeriver
exposeTracker *exposeTracker
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
mgr := &managerImpl{
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
settingsManager: settingsManager,
proxyGRPCServer: proxyGRPCServer,
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
mgr.exposeTracker = &exposeTracker{manager: mgr}
@@ -58,11 +53,11 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
// StartExposeReaper delegates to the expose tracker.
func (m *managerImpl) StartExposeReaper(ctx context.Context) {
func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeTracker.StartExposeReaper(ctx)
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -86,34 +81,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range s.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
@@ -122,7 +117,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -143,7 +138,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
return service, nil
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -152,29 +147,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, service); err != nil {
if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return s, nil
}
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
@@ -201,7 +196,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
return nil
}
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
@@ -219,7 +214,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
})
}
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
@@ -235,7 +230,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
return nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -259,7 +254,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(service, updateInfo)
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
@@ -271,7 +266,7 @@ type serviceUpdateInfo struct {
serviceEnabledChanged bool
}
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -309,7 +304,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
return &updateInfo, err
}
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
@@ -326,7 +321,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
return nil
}
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
@@ -340,54 +335,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
}
}
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "")
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
case service.Enabled && updateInfo.serviceEnabledChanged:
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default:
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
}
}
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
}
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
if len(mappings) == 0 {
return
}
update := &proto.GetMappingUpdateResponse{
Mapping: mappings,
}
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
@@ -399,7 +380,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -408,10 +389,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
@@ -426,20 +407,20 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err
}
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
if s.Source == service.SourceEphemeral {
m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain)
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error {
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -448,16 +429,16 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
return status.NewPermissionDeniedError()
}
var services []*reverseproxy.Service
var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID)
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
for _, service := range services {
if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil {
for _, svc := range services {
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
}
@@ -468,20 +449,14 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
return err
}
clusterMappings := make(map[string][]*proto.ProxyMapping)
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, service := range services {
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
for _, svc := range services {
if svc.Source == service.SourceEphemeral {
m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain)
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
}
for cluster, mappings := range clusterMappings {
m.sendMappingsToCluster(mappings, cluster)
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -491,7 +466,7 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -510,7 +485,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -527,50 +502,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
clusterMappings := make(map[string][]*proto.ProxyMapping)
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
}
for cluster, mappings := range clusterMappings {
m.sendMappingsToCluster(mappings, cluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -586,7 +553,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
@@ -600,7 +567,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -616,7 +583,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
@@ -634,7 +601,7 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error {
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
@@ -667,7 +634,7 @@ func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, p
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
@@ -676,31 +643,31 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
return nil, err
}
serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix)
serviceName, err := service.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
service := req.ToService(accountID, peerID, serviceName)
service.Source = reverseproxy.SourceEphemeral
svc := req.ToService(accountID, peerID, serviceName)
svc.Source = service.SourceEphemeral
if service.Domain == "" {
domain, err := m.buildRandomDomain(service.Name)
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", service.Name, err)
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
}
service.Domain = domain
svc.Domain = domain
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups)
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err)
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
}
service.Auth.BearerAuth.DistributionGroups = groupIDs
svc.Auth.BearerAuth.DistributionGroups = groupIDs
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
return nil, err
}
@@ -710,45 +677,45 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
}
now := time.Now()
service.Meta.LastRenewedAt = &now
service.SourcePeer = peerID
svc.Meta.LastRenewedAt = &now
svc.SourcePeer = peerID
if err := m.persistNewService(ctx, accountID, service); err != nil {
if err := m.persistNewService(ctx, accountID, svc); err != nil {
return nil, err
}
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID)
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID)
if alreadyTracked {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err)
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err)
}
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if !allowed {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err)
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err)
}
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta)
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", service.ID, err)
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return &reverseproxy.ExposeServiceResponse{
ServiceName: service.Name,
ServiceURL: "https://" + service.Domain,
Domain: service.Domain,
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
}, nil
}
func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 {
return []string{}, fmt.Errorf("no group names provided")
}
@@ -763,7 +730,7 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string
return groupIDs, nil
}
func (m *managerImpl) buildRandomDomain(name string) (string, error) {
func (m *Manager) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
@@ -778,7 +745,7 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) {
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
// Returns an error if the expose is not actively tracked.
func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
@@ -786,7 +753,7 @@ func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain
}
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
@@ -801,8 +768,8 @@ func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
service, err := m.lookupPeerService(ctx, accountID, peerID, domain)
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
@@ -811,41 +778,41 @@ func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peer
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode)
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByDomain(ctx, accountID, domain)
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if service.Source != reverseproxy.SourceEphemeral {
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if service.SourcePeer != peerID {
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return service, nil
return svc, nil
}
func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var service *reverseproxy.Service
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var svc *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if service.Source != reverseproxy.SourceEphemeral {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
}
if service.SourcePeer != peerID {
if svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
@@ -865,11 +832,11 @@ func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID,
peer = nil
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)

View File

@@ -11,17 +11,14 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -30,13 +27,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{
mgr := &Manager{
clusterDeriver: nil,
}
service := &reverseproxy.Service{
service := &rpservice.Service{
Domain: "example.com",
Auth: reverseproxy.AuthConfig{},
Auth: rpservice.AuthConfig{},
}
err := mgr.initializeServiceForCreate(ctx, accountID, service)
@@ -50,12 +47,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
})
t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{
mgr := &Manager{
clusterDeriver: nil,
}
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
@@ -97,7 +94,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -109,7 +106,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
},
@@ -120,7 +117,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -146,7 +143,7 @@ func TestCheckDomainAvailable(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError {
@@ -176,7 +173,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err)
@@ -189,9 +186,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err)
@@ -209,7 +206,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err)
@@ -225,10 +222,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "new.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
// Mock ExecuteInTransaction to execute the function immediately
@@ -247,7 +244,7 @@ func TestPersistNewService(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err)
@@ -258,10 +255,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "existing.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
mockStore.EXPECT().
@@ -270,12 +267,12 @@ func TestPersistNewService(t *testing.T) {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err)
@@ -285,21 +282,21 @@ func TestPersistNewService(t *testing.T) {
})
}
func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{}
mgr := &Manager{}
t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "hashed-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "",
},
@@ -312,18 +309,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "hashed-pin",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "",
},
@@ -336,18 +333,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "old-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "new-password",
},
@@ -362,10 +359,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}
func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{}
mgr := &Manager{}
existing := &reverseproxy.Service{
Meta: reverseproxy.ServiceMeta{
existing := &rpservice.Service{
Meta: rpservice.ServiceMeta{
CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(),
Status: "active",
},
@@ -373,7 +370,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
SessionPublicKey: "public-key",
}
updated := &reverseproxy.Service{
updated := &rpservice.Service{
Domain: "updated.com",
}
@@ -397,31 +394,32 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
IP: net.ParseIP("100.64.0.1"),
}
newEphemeralService := func() *reverseproxy.Service {
return &reverseproxy.Service{
newEphemeralService := func() *rpservice.Service {
return &rpservice.Service{
ID: serviceID,
AccountID: accountID,
Name: "test-service",
Domain: "test.example.com",
Source: reverseproxy.SourceEphemeral,
Source: rpservice.SourceEphemeral,
SourcePeer: ownerPeerID,
}
}
newPermanentService := func() *reverseproxy.Service {
return &reverseproxy.Service{
newPermanentService := func() *rpservice.Service {
return &rpservice.Service{
ID: serviceID,
AccountID: accountID,
Name: "api-service",
Domain: "api.example.com",
Source: reverseproxy.SourcePermanent,
Source: rpservice.SourcePermanent,
}
}
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close)
return srv
}
@@ -455,10 +453,10 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t),
proxyController: NewGRPCProxyController(newProxyServer(t)),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
@@ -482,7 +480,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
}
@@ -511,7 +509,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
}
@@ -553,10 +551,10 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t),
proxyController: NewGRPCProxyController(newProxyServer(t)),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired)
@@ -593,10 +591,10 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t),
proxyController: NewGRPCProxyController(newProxyServer(t)),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
@@ -609,19 +607,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
})
}
// noopExtraSettings is a minimal extra_settings.Manager for tests without external integrations.
type noopExtraSettings struct{}
func (n *noopExtraSettings) GetExtraSettings(_ context.Context, _ string) (*types.ExtraSettings, error) {
return &types.ExtraSettings{}, nil
}
func (n *noopExtraSettings) UpdateExtraSettings(_ context.Context, _, _ string, _ *types.ExtraSettings) (bool, error) {
return false, nil
}
var _ extra_settings.Manager = (*noopExtraSettings)(nil)
// testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list.
type testClusterDeriver struct {
domains []string
@@ -643,7 +628,7 @@ const (
)
// setupIntegrationTest creates a real SQLite store with seeded test data for integration tests.
func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
t.Helper()
ctx := context.Background()
@@ -691,30 +676,25 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
usersMgr := users.NewManager(testStore)
settingsMgr := settings.NewManager(testStore, usersMgr, &noopExtraSettings{}, permsMgr, settings.IdpConfig{})
var storedEvents []activity.Activity
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
storedEvents = append(storedEvents, activityID.(activity.Activity))
},
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
mgr := &managerImpl{
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permsMgr,
settingsManager: settingsMgr,
proxyGRPCServer: proxySrv,
proxyController: NewGRPCProxyController(proxySrv),
clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"},
},
@@ -788,7 +768,7 @@ func Test_validateExposePermission(t *testing.T) {
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "get account settings")
@@ -801,7 +781,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with random domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -816,7 +796,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
})
@@ -824,7 +804,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with custom domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
@@ -845,7 +825,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -858,7 +838,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("validates request fields", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
}
@@ -872,67 +852,67 @@ func TestCreateServiceFromPeer(t *testing.T) {
func TestExposeServiceRequestValidate(t *testing.T) {
tests := []struct {
name string
req reverseproxy.ExposeServiceRequest
req rpservice.ExposeServiceRequest
wantErr string
}{
{
name: "valid http request",
req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
},
{
name: "port zero rejected",
req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
},
{
name: "invalid pin format",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
@@ -950,7 +930,7 @@ func TestExposeServiceRequestValidate(t *testing.T) {
}
t.Run("nil receiver", func(t *testing.T) {
var req *reverseproxy.ExposeServiceRequest
var req *rpservice.ExposeServiceRequest
err := req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "request cannot be nil")
@@ -964,7 +944,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// First create a service
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -983,7 +963,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -1001,7 +981,7 @@ func TestStopServiceFromPeer(t *testing.T) {
t.Run("stops service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -1020,7 +1000,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -1038,7 +1018,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
// A new expose should succeed (not blocked by stale tracking)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
})
@@ -1050,7 +1030,7 @@ func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
@@ -1071,7 +1051,7 @@ func TestRenewServiceFromPeer(t *testing.T) {
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})

View File

@@ -1,4 +1,4 @@
package reverseproxy
package service
import (
"crypto/rand"
@@ -29,15 +29,15 @@ const (
Delete Operation = "delete"
)
type ProxyStatus string
type Status string
const (
StatusPending ProxyStatus = "pending"
StatusActive ProxyStatus = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
StatusPending Status = "pending"
StatusActive Status = "active"
StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil {
log.Fatalf("failed to create certificate manager: %v", err)
log.Fatalf("failed to create certificate service: %v", err)
}
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
@@ -152,10 +152,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
reverseProxyMgr := s.ReverseProxyManager()
srv.SetReverseProxyManager(reverseProxyMgr)
if reverseProxyMgr != nil {
reverseProxyMgr.StartExposeReaper(context.Background())
serviceMgr := s.ServiceManager()
srv.SetReverseProxyManager(serviceMgr)
if serviceMgr != nil {
serviceMgr.StartExposeReaper(context.Background())
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
@@ -168,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager())
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager())
proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
})
return proxyService
})
@@ -193,7 +194,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
log.Info("One-time token store initialized for proxy authentication")
return tokenStore
})

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -106,6 +108,12 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
})
}
func (s *BaseServer) ServiceProxyController() service.ProxyController {
return Create(s, func() service.ProxyController {
return nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer())
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store())

View File

@@ -8,9 +8,11 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil {
log.Fatalf("failed to create account manager: %v", err)
log.Fatalf("failed to create account service: %v", err)
}
s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager())
accountManager.SetServiceManager(s.ServiceManager())
})
return accountManager
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager {
var idpManager idp.Manager
var err error
// Use embedded IdP manager if embedded Dex is configured and enabled.
// Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err)
log.Fatalf("failed to create embedded IDP service: %v", err)
}
return idpManager
}
// Fall back to external IdP manager
// Fall back to external IdP service
if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil {
log.Fatalf("failed to create IDP manager: %v", err)
log.Fatalf("failed to create IDP service: %v", err)
}
}
return idpManager
})
}
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager())
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
})
}
@@ -190,15 +192,21 @@ func (s *BaseServer) RecordsManager() records.Manager {
})
}
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager {
return Create(s, func() reverseproxy.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.SettingsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
return proxymanager.NewManager(s.Store())
})
}
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager())
m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
return &m
})
}

View File

@@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
// Eagerly create the gRPC server so that all AfterInit hooks are registered
// before we iterate them. Lazy creation after the loop would miss hooks
// registered during GRPCServer() construction (e.g., SetProxyManager).
// registered during GRPCServer() construction (e.g., SetServiceManager).
s.GRPCServer()
for _, fn := range s.afterInit {

View File

@@ -10,7 +10,7 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbContext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
@@ -39,7 +39,7 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage)
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &reverseproxy.ExposeServiceRequest{
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol),
@@ -167,14 +167,14 @@ func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key
return accountID, peer, nil
}
func (s *Server) getReverseProxyManager() reverseproxy.Manager {
func (s *Server) getReverseProxyManager() rpservice.Manager {
s.reverseProxyMu.RLock()
defer s.reverseProxyMu.RUnlock()
return s.reverseProxyManager
}
// SetReverseProxyManager sets the reverse proxy manager on the server.
func (s *Server) SetReverseProxyManager(mgr reverseproxy.Manager) {
func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) {
s.reverseProxyMu.Lock()
defer s.reverseProxyMu.Unlock()
s.reverseProxyManager = mgr

View File

@@ -1,28 +1,23 @@
package grpc
import (
"context"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
)
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct {
ServiceID string
AccountID string
@@ -30,20 +25,24 @@ type tokenMetadata struct {
CreatedAt time.Time
}
// NewOneTimeTokenStore creates a new token store with automatic cleanup
// of expired tokens. The cleanupInterval determines how often expired
// tokens are removed from memory.
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
store := &OneTimeTokenStore{
tokens: make(map[string]*tokenMetadata),
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}),
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
type OneTimeTokenStore struct {
cache *cache.Cache[string]
ctx context.Context
}
// NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
}
// Start background cleanup goroutine
go store.cleanupExpired()
return store
return &OneTimeTokenStore{
cache: cache.New[string](cacheStore),
ctx: ctx,
}, nil
}
// GenerateToken creates a new cryptographically secure one-time token
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
//
// Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err)
}
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes)
hashedToken := hashToken(token)
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
metadata := &tokenMetadata{
ServiceID: serviceID,
AccountID: accountID,
ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(),
}
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
}
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
return "", fmt.Errorf("failed to store token: %w", err)
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl)
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Account ID doesn't match
// - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock()
defer s.mu.Unlock()
hashedToken := hashToken(token)
metadata, exists := s.tokens[token]
if !exists {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)",
serviceID, accountID)
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
if err != nil {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
return fmt.Errorf("invalid token")
}
// Check expiration
metadata := &tokenMetadata{}
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
return fmt.Errorf("invalid token metadata")
}
if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
return fmt.Errorf("token expired")
}
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)",
metadata.AccountID, accountID)
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch")
}
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)",
metadata.ServiceID, serviceID)
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch")
}
// Delete token immediately to enforce single-use
delete(s.tokens, token)
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
}
log.Infof("Token validated and consumed for proxy %s in account %s",
serviceID, accountID)
log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
return nil
}
// cleanupExpired removes expired tokens in the background to prevent memory leaks
func (s *OneTimeTokenStore) cleanupExpired() {
for {
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -24,8 +24,9 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
@@ -58,14 +59,20 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping
// Manager for access logs
accessLogManager accesslogs.Manager
// Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager
serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController rpservice.ProxyController
// Manager for proxy connections
proxyManager proxy.Manager
// Manager for peers
peersManager peers.Manager
@@ -104,7 +111,7 @@ type proxyConnection struct {
}
// NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
accessLogManager: accessLogMgr,
@@ -112,9 +119,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
tokenStore: tokenStore,
peersManager: peersManager,
usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel,
}
go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s
}
@@ -138,13 +147,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
}
}
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
}
}
}
}
// Close stops background goroutines.
func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel()
}
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) {
s.reverseProxyManager = manager
func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) {
s.serviceManager = manager
}
func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) {
s.proxyController = proxyController
}
// GetMappingUpdate handles the control stream with proxy clients
@@ -179,7 +208,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID)
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
log.WithFields(log.Fields{
"proxy_id": proxyID,
"address": proxyAddress,
@@ -187,8 +224,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster")
defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID)
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel()
log.Infof("Proxy %s disconnected", proxyID)
}()
@@ -200,6 +244,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2)
go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
select {
case err := <-errChan:
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
@@ -208,10 +255,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
}
}
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
return
}
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return fmt.Errorf("get services from store: %w", err)
}
@@ -220,7 +284,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("proxy address is invalid")
}
var filtered []*reverseproxy.Service
var filtered []*rpservice.Service
for _, service := range services {
if !service.Enabled {
continue
@@ -255,7 +319,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{
service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy.
rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token,
s.GetOIDCValidationConfig(),
),
@@ -389,61 +453,47 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
return urls
}
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.GetMappingUpdateResponse, clusterAddr string) {
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
updateResponse := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{update},
}
if clusterAddr == "" {
s.SendServiceUpdate(update)
s.SendServiceUpdate(updateResponse)
return
}
proxySet, ok := s.clusterProxies.Load(clusterAddr)
if !ok {
log.Debugf("No proxies connected for cluster %s", clusterAddr)
if s.proxyController == nil {
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
return
}
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
if len(proxyIDs) == 0 {
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
return
}
log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxyID := key.(string)
for _, proxyID := range proxyIDs {
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID)
msg := s.perProxyMessage(updateResponse, proxyID)
if msg == nil {
return true
continue
}
select {
case conn.sendChan <- msg:
log.Debugf("Sent service update to proxy %s in cluster %s", proxyID, clusterAddr)
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
}
}
return true
})
}
}
// perProxyMessage returns a copy of update with a fresh one-time token for
@@ -490,35 +540,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
}
}
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -537,7 +560,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}, nil
}
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin:
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
@@ -548,7 +571,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
}
}
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
return false, "", ""
@@ -562,7 +585,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
return true, "pin-user", proxyauth.MethodPIN
}
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) {
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
return false, "", ""
@@ -584,7 +607,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
}
}
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) {
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" {
return "", nil
}
@@ -624,7 +647,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}
if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
}
@@ -636,7 +659,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err)
}
@@ -651,22 +674,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
}
// protoStatusToInternal maps proto status to internal status
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus {
func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING:
return reverseproxy.StatusPending
return rpservice.StatusPending
case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
return reverseproxy.StatusActive
return rpservice.StatusActive
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
return reverseproxy.StatusTunnelNotCreated
return rpservice.StatusTunnelNotCreated
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
return reverseproxy.StatusCertificatePending
return rpservice.StatusCertificatePending
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
return reverseproxy.StatusCertificateFailed
return rpservice.StatusCertificateFailed
case proto.ProxyStatus_PROXY_STATUS_ERROR:
return reverseproxy.StatusError
return rpservice.StatusError
default:
return reverseproxy.StatusError
return rpservice.StatusError
}
}
@@ -731,7 +754,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
}
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId())
services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
@@ -794,8 +817,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig {
return reverseproxy.OIDCValidationConfig{
func (s *ProxyServiceServer) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return rpservice.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation,
@@ -854,12 +877,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return "", fmt.Errorf("get services: %w", err)
}
var service *reverseproxy.Service
var service *rpservice.Service
for _, svc := range services {
if svc.Domain == domain {
service = svc
@@ -925,8 +948,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
}
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID)
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetAccountServices(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account services: %w", err)
}
@@ -1047,8 +1070,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil
}
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx)
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil {
return nil, fmt.Errorf("get services: %w", err)
}
@@ -1062,7 +1085,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return nil, fmt.Errorf("service not found for domain: %s", domain)
}
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error {
func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil
}

View File

@@ -8,12 +8,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
err error
}
@@ -21,31 +21,31 @@ func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, account
return nil
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return []*service.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
@@ -56,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error {
return nil
}
@@ -68,16 +68,16 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID,
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) {
return &service.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
return &reverseproxy.ExposeServiceResponse{}, nil
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return &service.ExposeServiceResponse{}, nil
}
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
@@ -111,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
users map[string]*types.User
proxyErr error
userErr error
@@ -122,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
@@ -133,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
proxiesByAccount: map[string][]*service.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
@@ -144,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
@@ -157,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
@@ -169,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{Enabled: false},
},
}},
},
@@ -187,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
@@ -208,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -230,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -251,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
@@ -284,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) {
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
@@ -301,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
@@ -328,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
proxiesByAccount map[string][]*service.Service
err error
expectProxy bool
expectErr bool
@@ -337,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
@@ -350,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
proxiesByAccount: map[string][]*service.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
@@ -360,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
proxiesByAccount: map[string][]*service.Service{},
expectProxy: false,
expectErr: true,
},
@@ -378,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},

View File

@@ -1,19 +1,73 @@
package grpc
import (
"context"
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"sync"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/management/proto"
)
type testProxyController struct {
mu sync.Mutex
clusterProxies map[string]map[string]struct{}
}
func newTestProxyController() *testProxyController {
return &testProxyController{
clusterProxies: make(map[string]map[string]struct{}),
}
}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
}
func (c *testProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return rpservice.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clusterProxies[clusterAddr]; !ok {
c.clusterProxies[clusterAddr] = make(map[string]struct{})
}
c.clusterProxies[clusterAddr][proxyID] = struct{}{}
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error {
c.mu.Lock()
defer c.mu.Unlock()
if proxies, ok := c.clusterProxies[clusterAddr]; ok {
delete(proxies, proxyID)
}
return nil
}
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
c.mu.Lock()
defer c.mu.Unlock()
proxies, ok := c.clusterProxies[clusterAddr]
if !ok {
return nil
}
result := make([]string, 0, len(proxies))
for id := range proxies {
result = append(result, id)
}
return result
}
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
@@ -25,8 +79,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
_ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID)
return ch
}
@@ -41,12 +94,13 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
const numProxies = 3
@@ -67,11 +121,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
},
}
update := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdateToCluster(update, cluster)
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
@@ -101,12 +151,13 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
@@ -119,11 +170,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
Domain: "test.example.com",
}
update := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdateToCluster(update, cluster)
s.SendServiceUpdateToCluster(context.Background(), mapping, cluster)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
@@ -135,18 +182,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
// Delete operations should not generate tokens
assert.Empty(t, resp1.Mapping[0].AuthToken)
assert.Empty(t, resp2.Mapping[0].AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
}
s.SetProxyController(newTestProxyController())
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")

View File

@@ -26,7 +26,7 @@ import (
"github.com/netbirdio/netbird/shared/management/client/common"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/job"
@@ -82,7 +82,7 @@ type Server struct {
syncLimEnabled bool
syncLim int32
reverseProxyManager reverseproxy.Manager
reverseProxyManager rpservice.Manager
reverseProxyMu sync.RWMutex
}

View File

@@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -34,11 +34,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
serviceManager := &testValidateSessionServiceManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyManager := &testValidateSessionProxyManager{}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager)
proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore)
@@ -54,7 +58,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
testProxy := &service.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
@@ -62,15 +66,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
restrictedProxy := &service.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
@@ -78,8 +82,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store)
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
@@ -239,79 +243,101 @@ func TestValidateSession_MissingToken(t *testing.T) {
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
type testValidateSessionServiceManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) DeleteAllServices(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) StartExposeReaper(_ context.Context) {}
func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {}
type testValidateSessionProxyManager struct{}
func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
type testValidateSessionUsersManager struct {
store store.Store

View File

@@ -15,7 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
proxyController port_forwarding.Controller
settingsManager settings.Manager
serviceManager service.Manager
// config contains the management server configuration
config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.serviceManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
@@ -395,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
}
if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
}
@@ -730,7 +730,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}

View File

@@ -6,7 +6,7 @@ import (
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -140,5 +140,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager)
SetServiceManager(serviceManager service.Manager)
}

View File

@@ -27,8 +27,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"

View File

@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
}
// Register OAuth callback handler for proxy authentication

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore)
@@ -210,7 +211,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager,
)
proxyService.SetProxyManager(&testServiceManager{store: testStore})
proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil)

View File

@@ -12,7 +12,8 @@ import (
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +92,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, settingsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(reverseProxyManager)
am.SetServiceManager(reverseProxyManager)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
proxyMgr := proxymanager.NewManager(store)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, reverseproxymanager.NewGRPCProxyController(proxyServiceServer), domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/types"
@@ -358,12 +358,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
}
servicesTargets += len(service.Targets)
switch reverseproxy.ProxyStatus(service.Meta.Status) {
case reverseproxy.StatusActive:
switch rpservice.Status(service.Meta.Status) {
case rpservice.StatusActive:
servicesStatusActive++
case reverseproxy.StatusPending:
case rpservice.StatusPending:
servicesStatusPending++
case reverseproxy.StatusError, reverseproxy.StatusCertificateFailed, reverseproxy.StatusTunnelNotCreated:
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
servicesStatusError++
}

View File

@@ -6,7 +6,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -116,29 +116,29 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
},
},
},
Services: []*reverseproxy.Service{
Services: []*rpservice.Service{
{
ID: "svc1",
Enabled: true,
Targets: []*reverseproxy.Target{
Targets: []*rpservice.Target{
{TargetType: "peer"},
{TargetType: "host"},
},
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{Enabled: true},
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
},
Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusActive)},
Meta: rpservice.ServiceMeta{Status: string(rpservice.StatusActive)},
},
{
ID: "svc2",
Enabled: false,
Targets: []*reverseproxy.Target{
Targets: []*rpservice.Target{
{TargetType: "domain"},
},
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: true},
Auth: rpservice.AuthConfig{
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
},
Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusPending)},
Meta: rpservice.ServiceMeta{Status: string(rpservice.StatusPending)},
},
},
},

View File

@@ -12,7 +12,7 @@ import (
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
@@ -148,7 +148,7 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
// Mock implementation - no-op
}

View File

@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -33,23 +33,23 @@ type Manager interface {
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
reverseProxyManager reverseproxy.Manager
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
serviceManager service.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
reverseProxyManager: reverseproxyManager,
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
serviceManager: reverseproxyManager,
}
}
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
}
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err)
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err)
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)

View File

@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings
var eventsToStore []func()
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}

View File

@@ -28,9 +28,10 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -2075,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil
}
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
@@ -2090,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err
}
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
var s reverseproxy.Service
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
var s rpservice.Service
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
@@ -2121,7 +2122,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
}
}
s.Meta = reverseproxy.ServiceMeta{}
s.Meta = rpservice.ServiceMeta{}
if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time
}
@@ -2142,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
s.SessionPublicKey = sessionPublicKey.String
}
s.Targets = []*reverseproxy.Target{}
s.Targets = []*rpservice.Target{}
return &s, nil
})
if err != nil {
@@ -2154,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
}
serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service)
serviceMap := make(map[string]*rpservice.Service)
for i, s := range services {
serviceIDs[i] = s.ID
serviceMap[s.ID] = s
@@ -2165,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err
}
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
var t reverseproxy.Target
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
var t rpservice.Target
var path sql.NullString
err := row.Scan(
&t.ID,
@@ -4838,7 +4839,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil
}
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
@@ -4852,16 +4853,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
return nil
}
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
}
// Create target type instance outside transaction to avoid variable shadowing
targetType := &rpservice.Target{}
// Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
return err
}
@@ -4882,7 +4886,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
}
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store")
@@ -4895,13 +4899,13 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var service *reverseproxy.Service
var service *rpservice.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4919,30 +4923,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil
}
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
}
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
}
return serviceList, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
var service *reverseproxy.Service
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
var service *rpservice.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4960,13 +4942,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
return service, nil
}
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
var serviceList []*rpservice.Service
result := tx.Find(&serviceList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -4982,13 +4964,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
return serviceList, nil
}
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
var serviceList []*rpservice.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -5216,13 +5198,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
return query
}
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var target *reverseproxy.Target
var target *rpservice.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -5235,3 +5217,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
return target, nil
}
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
}
return nil
}
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
}
return addresses, nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
}
if result.RowsAffected > 0 {
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
}
return nil
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
}
for i := len(models) - 1; i >= 0; i-- {

View File

@@ -25,9 +25,10 @@ import (
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -252,14 +253,13 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error
CreateService(ctx context.Context, service *rpservice.Service) error
UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
@@ -271,9 +271,13 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
// GetCustomDomainsCounts returns the total and validated custom domain counts.
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
}

View File

@@ -12,9 +12,10 @@ import (
gomock "github.com/golang/mock/gomock"
dns "github.com/netbirdio/netbird/dns"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
zones "github.com/netbirdio/netbird/management/internals/modules/zones"
records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
types "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
}
// CleanupStaleProxies mocks base method.
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -293,7 +308,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
}
// CreateService mocks base method.
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, service)
ret0, _ := ret[0].(error)
@@ -1095,10 +1110,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
}
// GetAccountServices mocks base method.
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1109,21 +1124,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetAccountSettings mocks base method.
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
m.ctrl.T.Helper()
@@ -1214,6 +1214,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
}
// GetActiveProxyClusterAddresses mocks base method.
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper()
@@ -1828,10 +1843,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret0, _ := ret[0].(*reverseproxy.Service)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1843,10 +1858,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
}
// GetServiceByID mocks base method.
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].(*reverseproxy.Service)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1858,10 +1873,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
}
// GetServiceTargetByTargetID mocks base method.
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) {
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
ret0, _ := ret[0].(*reverseproxy.Target)
ret0, _ := ret[0].(*service.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1889,10 +1904,10 @@ func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock
}
// GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -2567,6 +2582,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
}
// SaveProxy mocks base method.
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
ret0, _ := ret[0].(error)
return ret0
}
// SaveProxy indicates an expected call of SaveProxy.
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper()
@@ -2762,8 +2791,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
}
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
}
// UpdateService mocks base method.
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, service)
ret0, _ := ret[0].(error)

View File

@@ -18,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy())
}
services := []*reverseproxy.Service{}
services := []*service.Service{}
for _, service := range a.Services {
services = append(services, service.Copy())
}
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
}
}
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
for _, target := range service.Targets {
if !target.Enabled {
continue
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
}
}
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
port, ok := a.resolveTargetPort(ctx, target)
if !ok {
return
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
}
}
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
if target.Port != 0 {
return target.Port, true
}
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
}
}
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
return &Policy{
ID: policyID,

View File

@@ -18,8 +18,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -37,7 +37,7 @@ type integrationTestSetup struct {
grpcServer *grpc.Server
grpcAddr string
cleanup func()
services []*reverseproxy.Service
services []*service.Service
}
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
@@ -66,13 +66,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
privKey := base64.StdEncoding.EncodeToString(priv)
// Create test services in the store
services := []*reverseproxy.Service{
services := []*service.Service{
{
ID: "rp-1",
AccountID: "test-account-1",
Name: "Test App 1",
Domain: "app1.test.proxy.io",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "10.0.0.1",
Port: 8080,
@@ -91,7 +91,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
AccountID: "test-account-1",
Name: "Test App 2",
Domain: "app2.test.proxy.io",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "10.0.0.2",
Port: 8080,
@@ -112,7 +112,8 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
}
// Create real token store
tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
// Create real users manager
usersManager := users.NewManager(testStore)
@@ -124,17 +125,23 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
HMACKey: []byte("test-hmac-key"),
}
proxyManager := &testProxyManager{}
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
oidcConfig,
nil,
usersManager,
proxyManager,
)
// Use store-backed service manager
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetProxyManager(svcMgr)
proxyService.SetServiceManager(svcMgr)
proxyController := &testProxyController{}
proxyService.SetProxyController(proxyController)
// Start real gRPC server
lis, err := net.Listen("tcp", "127.0.0.1:0")
@@ -185,6 +192,52 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string,
return nil, 0, nil
}
// testProxyManager is a mock implementation of proxy.Manager for testing.
type testProxyManager struct{}
func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testProxyManager) Disconnect(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error {
return nil
}
func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) {
return nil, nil
}
func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error {
return nil
}
// testProxyController is a mock implementation of rpservice.ProxyController for testing.
type testProxyController struct{}
func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) {
// noop
}
func (c *testProxyController) GetOIDCValidationConfig() service.OIDCValidationConfig {
return service.OIDCValidationConfig{}
}
func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error {
return nil
}
func (c *testProxyController) GetProxiesForCluster(_ string) []string {
return nil
}
// storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct {
store store.Store
@@ -195,19 +248,19 @@ func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accou
return nil
}
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented")
}
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, errors.New("not implemented")
}
@@ -219,7 +272,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context,
return nil
}
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return nil
}
@@ -231,15 +284,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID
return nil
}
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
}
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}