[management] Store connected proxies in DB (#5472)

Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
This commit is contained in:
Pascal Fischer
2026-03-03 18:39:46 +01:00
committed by GitHub
parent 05b66e73bc
commit d7c8e37ff4
52 changed files with 1727 additions and 924 deletions

View File

@@ -27,21 +27,21 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
}
type proxyURLProvider interface {
GetConnectedProxyURLs() []string
type proxyManager interface {
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
}
type Manager struct {
store store
validator domain.Validator
proxyURLProvider proxyURLProvider
proxyManager proxyManager
permissionsManager permissions.Manager
}
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
return Manager{
store: store,
proxyURLProvider: proxyURLProvider,
store: store,
proxyManager: proxyMgr,
validator: domain.Validator{
Resolver: net.DefaultResolver,
},
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList()
log.WithFields(log.Fields{
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID,
"proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list")
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
// Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false
for _, cluster := range allowList {
if cluster == targetCluster {
@@ -221,25 +228,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
}
}
// GetClusterDomains returns a list of proxy cluster domains.
func (m Manager) GetClusterDomains() []string {
return m.proxyURLAllowList()
}
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
if m.proxyManager == nil {
return nil
}
return reverseProxyAddresses
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
if err != nil {
return nil
}
return addresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList()
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available")
}

View File

@@ -0,0 +1,36 @@
package proxy
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"time"
"github.com/netbirdio/netbird/shared/management/proto"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
// Controller is responsible for managing proxy clusters and routing service updates.
type Controller interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -0,0 +1,88 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
metrics *metrics
}
// NewGRPCController creates a new GRPCController.
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &GRPCController{
proxyGRPCServer: proxyGRPCServer,
metrics: m,
}, nil
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
c.metrics.IncrementProxyConnectionCount(clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
c.metrics.DecrementProxyConnectionCount(clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -0,0 +1,115 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
metrics *metrics
}
// NewManager creates a new proxy Manager
func NewManager(store store, meter metric.Meter) (*Manager, error) {
m, err := newMetrics(meter)
if err != nil {
return nil, err
}
return &Manager{
store: store,
metrics: m,
}, nil
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
m.metrics.IncrementProxyHeartbeatCount()
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
type metrics struct {
proxyConnectionCount metric.Int64UpDownCounter
serviceUpdateSendCount metric.Int64Counter
proxyHeartbeatCount metric.Int64Counter
}
func newMetrics(meter metric.Meter) (*metrics, error) {
proxyConnectionCount, err := meter.Int64UpDownCounter(
"management_proxy_connection_count",
metric.WithDescription("Number of active proxy connections"),
metric.WithUnit("{connection}"),
)
if err != nil {
return nil, err
}
serviceUpdateSendCount, err := meter.Int64Counter(
"management_proxy_service_update_send_count",
metric.WithDescription("Total number of service updates sent to proxies"),
metric.WithUnit("{update}"),
)
if err != nil {
return nil, err
}
proxyHeartbeatCount, err := meter.Int64Counter(
"management_proxy_heartbeat_count",
metric.WithDescription("Total number of proxy heartbeats received"),
metric.WithUnit("{heartbeat}"),
)
if err != nil {
return nil, err
}
return &metrics{
proxyConnectionCount: proxyConnectionCount,
serviceUpdateSendCount: serviceUpdateSendCount,
proxyHeartbeatCount: proxyHeartbeatCount,
}, nil
}
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
m.proxyConnectionCount.Add(context.Background(), -1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
m.serviceUpdateSendCount.Add(context.Background(), 1,
metric.WithAttributes(
attribute.String("cluster", clusterAddr),
))
}
func (m *metrics) IncrementProxyHeartbeatCount() {
m.proxyHeartbeatCount.Add(context.Background(), 1)
}

View File

@@ -0,0 +1,199 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./manager.go
// Package proxy is a generated GoMock package.
package proxy
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CleanupStale mocks base method.
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStale indicates an expected call of CleanupStale.
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
}
// Connect mocks base method.
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Connect indicates an expected call of Connect.
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
}
// Disconnect mocks base method.
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Disconnect indicates an expected call of Disconnect.
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
}
// GetActiveClusterAddresses mocks base method.
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
}
// MockController is a mock of Controller interface.
type MockController struct {
ctrl *gomock.Controller
recorder *MockControllerMockRecorder
}
// MockControllerMockRecorder is the mock recorder for MockController.
type MockControllerMockRecorder struct {
mock *MockController
}
// NewMockController creates a new mock instance.
func NewMockController(ctrl *gomock.Controller) *MockController {
mock := &MockController{ctrl: ctrl}
mock.recorder = &MockControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,6 +1,6 @@
package reverseproxy
package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
@@ -14,7 +14,7 @@ type Manager interface {
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)

View File

@@ -1,8 +1,8 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package reverseproxy is a generated GoMock package.
package reverseproxy
// Package service is a generated GoMock package.
package service
import (
context "context"
@@ -239,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
)
type handler struct {
manager reverseproxy.Manager
manager rpservice.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)

View File

@@ -27,7 +27,7 @@ type trackedExpose struct {
type exposeTracker struct {
activeExposes sync.Map
exposeCreateMu sync.Mutex
manager *managerImpl
manager *Manager
}
func exposeKey(peerID, domain string) string {

View File

@@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
)
func TestExposeKey(t *testing.T) {
@@ -120,7 +120,7 @@ func TestReapExpiredExposes(t *testing.T) {
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -156,7 +156,7 @@ func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -191,7 +191,7 @@ func TestConcurrentTrackAndCount(t *testing.T) {
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})

View File

@@ -11,17 +11,15 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -33,24 +31,22 @@ type ClusterDeriver interface {
GetClusterDomains() []string
}
type managerImpl struct {
type Manager struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
settingsManager settings.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
proxyController proxy.Controller
clusterDeriver ClusterDeriver
exposeTracker *exposeTracker
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
mgr := &managerImpl{
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
mgr := &Manager{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
settingsManager: settingsManager,
proxyGRPCServer: proxyGRPCServer,
proxyController: proxyController,
clusterDeriver: clusterDeriver,
}
mgr.exposeTracker = &exposeTracker{manager: mgr}
@@ -58,11 +54,11 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
// StartExposeReaper delegates to the expose tracker.
func (m *managerImpl) StartExposeReaper(ctx context.Context) {
func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeTracker.StartExposeReaper(ctx)
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -86,34 +82,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
return services, nil
}
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
for _, target := range service.Targets {
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range s.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
@@ -122,7 +118,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
return nil
}
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -143,7 +139,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
return service, nil
}
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -152,29 +148,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err
}
if err := m.persistNewService(ctx, accountID, service); err != nil {
if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return s, nil
}
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
@@ -201,7 +197,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
return nil
}
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
@@ -219,7 +215,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
})
}
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
@@ -235,7 +231,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
return nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -259,7 +255,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(service, updateInfo)
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
@@ -271,7 +267,7 @@ type serviceUpdateInfo struct {
serviceEnabledChanged bool
}
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -309,7 +305,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
return &updateInfo, err
}
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
@@ -326,7 +322,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
return nil
}
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
@@ -340,54 +336,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
}
}
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "")
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
case service.Enabled && updateInfo.serviceEnabledChanged:
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default:
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
}
}
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
}
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
if len(mappings) == 0 {
return
}
update := &proto.GetMappingUpdateResponse{
Mapping: mappings,
}
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
@@ -399,7 +381,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -408,9 +390,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
var err error
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
@@ -429,20 +412,20 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err
}
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
if s.Source == service.SourceEphemeral {
m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain)
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error {
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
@@ -451,16 +434,16 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
return status.NewPermissionDeniedError()
}
var services []*reverseproxy.Service
var services []*service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID)
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
for _, service := range services {
if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil {
for _, svc := range services {
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
}
@@ -471,20 +454,14 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
return err
}
clusterMappings := make(map[string][]*proto.ProxyMapping)
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
oidcCfg := m.proxyController.GetOIDCValidationConfig()
for _, service := range services {
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
for _, svc := range services {
if svc.Source == service.SourceEphemeral {
m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain)
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
}
for cluster, mappings := range clusterMappings {
m.sendMappingsToCluster(mappings, cluster)
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -494,7 +471,7 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -513,7 +490,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -530,50 +507,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
})
}
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return fmt.Errorf("failed to get service: %w", err)
}
err = m.replaceHostByLookup(ctx, accountID, service)
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("failed to get services: %w", err)
}
clusterMappings := make(map[string][]*proto.ProxyMapping)
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
for _, service := range services {
err = m.replaceHostByLookup(ctx, accountID, service)
for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
}
mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
}
for cluster, mappings := range clusterMappings {
m.sendMappingsToCluster(mappings, cluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
}
return nil
}
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -589,7 +558,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
return services, nil
}
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
@@ -603,7 +572,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
return service, nil
}
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -619,7 +588,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
return services, nil
}
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
@@ -637,7 +606,7 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error {
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
@@ -670,7 +639,7 @@ func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, p
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
@@ -679,31 +648,31 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
return nil, err
}
serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix)
serviceName, err := service.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
service := req.ToService(accountID, peerID, serviceName)
service.Source = reverseproxy.SourceEphemeral
svc := req.ToService(accountID, peerID, serviceName)
svc.Source = service.SourceEphemeral
if service.Domain == "" {
domain, err := m.buildRandomDomain(service.Name)
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", service.Name, err)
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
}
service.Domain = domain
svc.Domain = domain
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups)
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err)
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
}
service.Auth.BearerAuth.DistributionGroups = groupIDs
svc.Auth.BearerAuth.DistributionGroups = groupIDs
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
return nil, err
}
@@ -713,45 +682,45 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
}
now := time.Now()
service.Meta.LastRenewedAt = &now
service.SourcePeer = peerID
svc.Meta.LastRenewedAt = &now
svc.SourcePeer = peerID
if err := m.persistNewService(ctx, accountID, service); err != nil {
if err := m.persistNewService(ctx, accountID, svc); err != nil {
return nil, err
}
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID)
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID)
if alreadyTracked {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err)
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err)
}
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if !allowed {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err)
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err)
}
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta)
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", service.ID, err)
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return &reverseproxy.ExposeServiceResponse{
ServiceName: service.Name,
ServiceURL: "https://" + service.Domain,
Domain: service.Domain,
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
}, nil
}
func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
if len(groupNames) == 0 {
return []string{}, fmt.Errorf("no group names provided")
}
@@ -766,7 +735,7 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string
return groupIDs, nil
}
func (m *managerImpl) buildRandomDomain(name string) (string, error) {
func (m *Manager) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
@@ -781,7 +750,7 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) {
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
// Returns an error if the expose is not actively tracked.
func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
@@ -789,7 +758,7 @@ func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain
}
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
@@ -804,8 +773,8 @@ func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
service, err := m.lookupPeerService(ctx, accountID, peerID, domain)
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
@@ -814,41 +783,41 @@ func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peer
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode)
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByDomain(ctx, accountID, domain)
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if service.Source != reverseproxy.SourceEphemeral {
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if service.SourcePeer != peerID {
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return service, nil
return svc, nil
}
func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var service *reverseproxy.Service
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
var svc *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
return err
}
if service.Source != reverseproxy.SourceEphemeral {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
}
if service.SourcePeer != peerID {
if svc.SourcePeer != peerID {
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
@@ -868,11 +837,11 @@ func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID,
peer = nil
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)

View File

@@ -10,21 +10,21 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -33,13 +33,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{
mgr := &Manager{
clusterDeriver: nil,
}
service := &reverseproxy.Service{
service := &rpservice.Service{
Domain: "example.com",
Auth: reverseproxy.AuthConfig{},
Auth: rpservice.AuthConfig{},
}
err := mgr.initializeServiceForCreate(ctx, accountID, service)
@@ -53,12 +53,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
})
t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{
mgr := &Manager{
clusterDeriver: nil,
}
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
@@ -100,7 +100,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -112,7 +112,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
},
@@ -123,7 +123,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -149,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError {
@@ -179,7 +179,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err)
@@ -192,9 +192,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err)
@@ -212,7 +212,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil)
mgr := &managerImpl{}
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err)
@@ -228,10 +228,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "new.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
// Mock ExecuteInTransaction to execute the function immediately
@@ -250,7 +250,7 @@ func TestPersistNewService(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err)
@@ -261,10 +261,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "existing.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
mockStore.EXPECT().
@@ -273,12 +273,12 @@ func TestPersistNewService(t *testing.T) {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
})
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err)
@@ -288,21 +288,21 @@ func TestPersistNewService(t *testing.T) {
})
}
func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{}
mgr := &Manager{}
t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "hashed-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "",
},
@@ -315,18 +315,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "hashed-pin",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "",
},
@@ -339,18 +339,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "old-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "new-password",
},
@@ -365,10 +365,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}
func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{}
mgr := &Manager{}
existing := &reverseproxy.Service{
Meta: reverseproxy.ServiceMeta{
existing := &rpservice.Service{
Meta: rpservice.Meta{
CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(),
Status: "active",
},
@@ -376,7 +376,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
SessionPublicKey: "public-key",
}
updated := &reverseproxy.Service{
updated := &rpservice.Service{
Domain: "updated.com",
}
@@ -400,31 +400,32 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
IP: net.ParseIP("100.64.0.1"),
}
newEphemeralService := func() *reverseproxy.Service {
return &reverseproxy.Service{
newEphemeralService := func() *rpservice.Service {
return &rpservice.Service{
ID: serviceID,
AccountID: accountID,
Name: "test-service",
Domain: "test.example.com",
Source: reverseproxy.SourceEphemeral,
Source: rpservice.SourceEphemeral,
SourcePeer: ownerPeerID,
}
}
newPermanentService := func() *reverseproxy.Service {
return &reverseproxy.Service{
newPermanentService := func() *rpservice.Service {
return &rpservice.Service{
ID: serviceID,
AccountID: accountID,
Name: "api-service",
Domain: "api.example.com",
Source: reverseproxy.SourcePermanent,
Source: rpservice.SourcePermanent,
}
}
newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer {
t.Helper()
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close)
return srv
}
@@ -458,10 +459,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &managerImpl{
store: mockStore,
accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t),
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
@@ -485,7 +490,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
}
@@ -514,7 +519,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
return fn(txMock)
})
mgr := &managerImpl{
mgr := &Manager{
store: mockStore,
}
@@ -556,10 +561,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &managerImpl{
store: mockStore,
accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t),
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired)
@@ -596,10 +605,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID).
Return(testPeer, nil)
mgr := &managerImpl{
store: mockStore,
accountManager: mockAccountMgr,
proxyGRPCServer: newProxyServer(t),
mgr := &Manager{
store: mockStore,
accountManager: mockAccountMgr,
proxyController: func() proxy.Controller {
c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
return c
}(),
}
err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed)
@@ -612,19 +625,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
})
}
// noopExtraSettings is a minimal extra_settings.Manager for tests without external integrations.
type noopExtraSettings struct{}
func (n *noopExtraSettings) GetExtraSettings(_ context.Context, _ string) (*types.ExtraSettings, error) {
return &types.ExtraSettings{}, nil
}
func (n *noopExtraSettings) UpdateExtraSettings(_ context.Context, _, _ string, _ *types.ExtraSettings) (bool, error) {
return false, nil
}
var _ extra_settings.Manager = (*noopExtraSettings)(nil)
// testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list.
type testClusterDeriver struct {
domains []string
@@ -646,7 +646,7 @@ const (
)
// setupIntegrationTest creates a real SQLite store with seeded test data for integration tests.
func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
t.Helper()
ctx := context.Background()
@@ -694,30 +694,28 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
require.NoError(t, err)
permsMgr := permissions.NewManager(testStore)
usersMgr := users.NewManager(testStore)
settingsMgr := settings.NewManager(testStore, usersMgr, &noopExtraSettings{}, permsMgr, settings.IdpConfig{})
var storedEvents []activity.Activity
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
storedEvents = append(storedEvents, activityID.(activity.Activity))
},
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
mgr := &managerImpl{
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permsMgr,
settingsManager: settingsMgr,
proxyGRPCServer: proxySrv,
proxyController: proxyController,
clusterDeriver: &testClusterDeriver{
domains: []string{"test.netbird.io"},
},
@@ -791,7 +789,7 @@ func Test_validateExposePermission(t *testing.T) {
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
mgr := &managerImpl{store: mockStore}
mgr := &Manager{store: mockStore}
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "get account settings")
@@ -804,7 +802,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with random domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -819,7 +817,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
})
@@ -827,7 +825,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with custom domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
@@ -848,7 +846,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -861,7 +859,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("validates request fields", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
}
@@ -875,67 +873,67 @@ func TestCreateServiceFromPeer(t *testing.T) {
func TestExposeServiceRequestValidate(t *testing.T) {
tests := []struct {
name string
req reverseproxy.ExposeServiceRequest
req rpservice.ExposeServiceRequest
wantErr string
}{
{
name: "valid http request",
req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
},
{
name: "port zero rejected",
req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
},
{
name: "invalid pin format",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
@@ -953,7 +951,7 @@ func TestExposeServiceRequestValidate(t *testing.T) {
}
t.Run("nil receiver", func(t *testing.T) {
var req *reverseproxy.ExposeServiceRequest
var req *rpservice.ExposeServiceRequest
err := req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "request cannot be nil")
@@ -967,7 +965,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// First create a service
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -986,7 +984,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -1004,7 +1002,7 @@ func TestStopServiceFromPeer(t *testing.T) {
t.Run("stops service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
@@ -1023,7 +1021,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -1041,7 +1039,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
// A new expose should succeed (not blocked by stale tracking)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
})
@@ -1053,7 +1051,7 @@ func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
@@ -1074,7 +1072,7 @@ func TestRenewServiceFromPeer(t *testing.T) {
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
@@ -1129,25 +1127,32 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
mockPerms := permissions.NewMockManager(ctrl)
mockAcct := account.NewMockManager(ctrl)
mockGRPC := &nbgrpc.ProxyServiceServer{}
mgr := &managerImpl{
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
mgr := &Manager{
store: sqlStore,
permissionsManager: mockPerms,
accountManager: mockAcct,
proxyGRPCServer: mockGRPC,
proxyController: proxyController,
}
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-1",
AccountID: accountID,
Domain: "test.example.com",
ProxyCluster: "cluster1",
Enabled: true,
Targets: []*reverseproxy.Target{
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-1"},
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-2"},
{AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-3"},
Targets: []*rpservice.Target{
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-1"},
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-2"},
{AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-3"},
},
}

View File

@@ -1,4 +1,4 @@
package reverseproxy
package service
import (
"crypto/rand"
@@ -14,6 +14,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
@@ -29,15 +30,15 @@ const (
Delete Operation = "delete"
)
type ProxyStatus string
type Status string
const (
StatusPending ProxyStatus = "pending"
StatusActive ProxyStatus = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed"
StatusError ProxyStatus = "error"
StatusPending Status = "pending"
StatusActive Status = "active"
StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
@@ -111,14 +112,7 @@ func (a *AuthConfig) ClearSecrets() {
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
type Meta struct {
CreatedAt time.Time
CertificateIssuedAt *time.Time
Status string
@@ -135,11 +129,11 @@ type Service struct {
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent'"`
Auth AuthConfig `gorm:"serializer:json"`
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent'"`
SourcePeer string
}
@@ -165,7 +159,7 @@ func NewService(accountID, name, domain, proxyCluster string, targets []*Target,
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
s.Meta = Meta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
@@ -239,7 +233,7 @@ func (s *Service) ToAPIResponse() *api.Service {
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {

View File

@@ -1,4 +1,4 @@
package reverseproxy
package service
import (
"errors"
@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
@@ -109,7 +110,7 @@ func TestIsDefaultPort(t *testing.T) {
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
oidcConfig := proxy.OIDCValidationConfig{}
tests := []struct {
name string
@@ -202,7 +203,7 @@ func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
@@ -219,7 +220,7 @@ func TestToProtoMapping_OperationTypes(t *testing.T) {
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}