Compare commits

...

9 Commits

Author SHA1 Message Date
pascal
5c049f6f09 delete tests 2026-02-24 16:53:49 +01:00
pascal
740c726a78 extract proxy controller 2026-02-23 15:28:20 +01:00
pascal
3af287ebab move service manager 2026-02-20 01:21:05 +01:00
pascal
d4d885d434 Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy 2026-02-20 00:35:10 +01:00
pascal
d212332f5d store proxies in DB 2026-02-20 00:28:45 +01:00
pascal
0e11258e97 fix test 2026-02-19 13:30:26 +01:00
pascal
31ecf8f1f5 allow redis for token store 2026-02-19 13:28:36 +01:00
pascal
e2df1fb35e add accountID when sending update to cluster 2026-02-19 12:02:31 +01:00
pascal
942cd5dc72 export methods 2026-02-19 10:50:32 +01:00
47 changed files with 949 additions and 3827 deletions

View File

@@ -27,21 +27,21 @@ type store interface {
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
} }
type proxyURLProvider interface { type proxyManager interface {
GetConnectedProxyURLs() []string GetActiveClusterAddresses(ctx context.Context) ([]string, error)
} }
type Manager struct { type Manager struct {
store store store store
validator domain.Validator validator domain.Validator
proxyURLProvider proxyURLProvider proxyManager proxyManager
permissionsManager permissions.Manager permissionsManager permissions.Manager
} }
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager { func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
return Manager{ return Manager{
store: store, store: store,
proxyURLProvider: proxyURLProvider, proxyManager: proxyMgr,
validator: domain.Validator{ validator: domain.Validator{
Resolver: net.DefaultResolver, Resolver: net.DefaultResolver,
}, },
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
// Add connected proxy clusters as free domains. // Add connected proxy clusters as free domains.
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
log.WithFields(log.Fields{ if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
log.WithContext(ctx).WithFields(log.Fields{
"accountID": accountID, "accountID": accountID,
"proxyAllowList": allowList, "proxyAllowList": allowList,
}).Debug("getting domains with proxy allow list") }).Debug("getting domains with proxy allow list")
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
} }
// Verify the target cluster is in the available clusters // Verify the target cluster is in the available clusters
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
clusterValid := false clusterValid := false
for _, cluster := range allowList { for _, cluster := range allowList {
if cluster == targetCluster { if cluster == targetCluster {
@@ -221,21 +228,14 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
} }
} }
// proxyURLAllowList retrieves a list of currently connected proxies and
// their URLs
func (m Manager) proxyURLAllowList() []string {
var reverseProxyAddresses []string
if m.proxyURLProvider != nil {
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
}
return reverseProxyAddresses
}
// DeriveClusterFromDomain determines the proxy cluster for a given domain. // DeriveClusterFromDomain determines the proxy cluster for a given domain.
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
allowList := m.proxyURLAllowList() allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
if err != nil {
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
}
if len(allowList) == 0 { if len(allowList) == 0 {
return "", fmt.Errorf("no proxy clusters available") return "", fmt.Errorf("no proxy clusters available")
} }

View File

@@ -0,0 +1,15 @@
package proxy
import (
"context"
"time"
)
// Manager defines the interface for proxy operations
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
}

View File

@@ -0,0 +1,106 @@
package manager
import (
"context"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// store defines the interface for proxy persistence operations
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
}
// Manager handles all proxy operations
type Manager struct {
store store
}
// NewManager creates a new proxy Manager
func NewManager(store store) Manager {
return Manager{
store: store,
}
}
// Connect registers a new proxy connection in the database
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
ClusterAddress: clusterAddress,
IPAddress: ipAddress,
LastSeen: now,
ConnectedAt: &now,
Status: "connected",
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
"clusterAddress": clusterAddress,
"ipAddress": ipAddress,
}).Info("proxy connected")
return nil
}
// Disconnect marks a proxy as disconnected in the database
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
now := time.Now()
p := &proxy.Proxy{
ID: proxyID,
Status: "disconnected",
DisconnectedAt: &now,
LastSeen: now,
}
if err := m.store.SaveProxy(ctx, p); err != nil {
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
return err
}
log.WithContext(ctx).WithFields(log.Fields{
"proxyID": proxyID,
}).Info("proxy disconnected")
return nil
}
// Heartbeat updates the proxy's last seen timestamp
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
return nil
}
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
return nil, err
}
return addresses, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
return err
}
return nil
}

View File

@@ -0,0 +1,20 @@
package proxy
import "time"
// Proxy represents a reverse proxy instance
type Proxy struct {
ID string `gorm:"primaryKey;type:varchar(255)"`
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
IPAddress string `gorm:"type:varchar(45)"`
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
ConnectedAt *time.Time
DisconnectedAt *time.Time
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
CreatedAt time.Time
UpdatedAt time.Time
}
func (Proxy) TableName() string {
return "proxies"
}

View File

@@ -1,9 +1,11 @@
package reverseproxy package service
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod //go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import ( import (
"context" "context"
"github.com/netbirdio/netbird/shared/management/proto"
) )
type Manager interface { type Manager interface {
@@ -13,7 +15,7 @@ type Manager interface {
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error) GetGlobalServices(ctx context.Context) ([]*Service, error)
@@ -21,3 +23,12 @@ type Manager interface {
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
} }
// ProxyController is responsible for managing proxy clusters and routing service updates.
type ProxyController interface {
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
GetOIDCValidationConfig() OIDCValidationConfig
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
}

View File

@@ -1,14 +1,15 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go // Source: ./interface.go
// Package reverseproxy is a generated GoMock package. // Package service is a generated GoMock package.
package reverseproxy package service
import ( import (
context "context" context "context"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
proto "github.com/netbirdio/netbird/shared/management/proto"
) )
// MockManager is a mock of Manager interface. // MockManager is a mock of Manager interface.
@@ -196,7 +197,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
} }
// SetStatus mocks base method. // SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error { func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status) ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -223,3 +224,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
} }
// MockProxyController is a mock of ProxyController interface.
type MockProxyController struct {
ctrl *gomock.Controller
recorder *MockProxyControllerMockRecorder
}
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
type MockProxyControllerMockRecorder struct {
mock *MockProxyController
}
// NewMockProxyController creates a new mock instance.
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
mock := &MockProxyController{ctrl: ctrl}
mock.recorder = &MockProxyControllerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
return m.recorder
}
// GetOIDCValidationConfig mocks base method.
func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
ret0, _ := ret[0].(OIDCValidationConfig)
return ret0
}
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
}
// GetProxiesForCluster mocks base method.
func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
ret0, _ := ret[0].([]string)
return ret0
}
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
}
// RegisterProxyToCluster mocks base method.
func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
}
// SendServiceUpdateToCluster mocks base method.
func (m *MockProxyController) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SendServiceUpdateToCluster", accountID, update, clusterAddr)
}
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(accountID, update, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), accountID, update, clusterAddr)
}
// UnregisterProxyFromCluster mocks base method.
func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
}

View File

@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context" nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
) )
type handler struct { type handler struct {
manager reverseproxy.Manager manager rpservice.Manager
} }
// RegisterEndpoints registers all service HTTP endpoints. // RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{ h := &handler{
manager: manager, manager: manager,
} }
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return return
} }
service := new(reverseproxy.Service) service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId) service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil { if err = service.Validate(); err != nil {
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return return
} }
service := new(reverseproxy.Service) service := new(rpservice.Service)
service.ID = serviceID service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId) service.FromAPIRequest(&req, userAuth.AccountId)

View File

@@ -0,0 +1,74 @@
package manager
import (
"context"
"sync"
log "github.com/sirupsen/logrus"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/shared/management/proto"
)
// GRPCProxyController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
type GRPCProxyController struct {
proxyGRPCServer *nbgrpc.ProxyServiceServer
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
}
// NewGRPCProxyController creates a new GRPCProxyController.
func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
return &GRPCProxyController{
proxyGRPCServer: proxyGRPCServer,
}
}
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
}
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return c.proxyGRPCServer.GetOIDCValidationConfig()
}
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
return nil
}
// UnregisterProxyFromCluster removes a proxy from a cluster.
func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
if clusterAddr == "" {
return nil
}
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
}
return nil
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)
if !ok {
return nil
}
var proxies []string
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
proxies = append(proxies, key.(string))
return true
})
return proxies
}

View File

@@ -7,9 +7,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@@ -26,26 +25,26 @@ type ClusterDeriver interface {
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
} }
type managerImpl struct { type Manager struct {
store store.Store store store.Store
accountManager account.Manager accountManager account.Manager
permissionsManager permissions.Manager permissionsManager permissions.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer proxyController service.ProxyController
clusterDeriver ClusterDeriver clusterDeriver ClusterDeriver
} }
// NewManager creates a new service manager. // NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager {
return &managerImpl{ return &Manager{
store: store, store: store,
accountManager: accountManager, accountManager: accountManager,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
proxyGRPCServer: proxyGRPCServer, proxyController: proxyController,
clusterDeriver: clusterDeriver, clusterDeriver: clusterDeriver,
} }
} }
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -69,34 +68,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri
return services, nil return services, nil
} }
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error { func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
for _, target := range service.Targets { for _, target := range s.Targets {
switch target.TargetType { switch target.TargetType {
case reverseproxy.TargetTypePeer: case service.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err) log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder target.Host = unknownHostPlaceholder
continue continue
} }
target.Host = peer.IP.String() target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost: case service.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder target.Host = unknownHostPlaceholder
continue continue
} }
target.Host = resource.Prefix.Addr().String() target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain: case service.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
target.Host = unknownHostPlaceholder target.Host = unknownHostPlaceholder
continue continue
} }
target.Host = resource.Domain target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet: case service.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource // For subnets we do not do any lookups on the resource
default: default:
return fmt.Errorf("unknown target type: %s", target.TargetType) return fmt.Errorf("unknown target type: %s", target.TargetType)
@@ -105,7 +104,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string,
return nil return nil
} }
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -126,7 +125,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service
return service, nil return service, nil
} }
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -135,29 +134,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
return nil, err return nil, err
} }
if err := m.persistNewService(ctx, accountID, service); err != nil { if err := m.persistNewService(ctx, accountID, s); err != nil {
return nil, err return nil, err
} }
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta()) m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil return s, nil
} }
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error { func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
if m.clusterDeriver != nil { if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil { if err != nil {
@@ -184,7 +183,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID
return nil return nil
} }
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error { func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err return err
@@ -202,7 +201,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s
}) })
} }
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
@@ -218,7 +217,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor
return nil return nil
} }
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil { if err != nil {
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
@@ -242,7 +241,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
} }
m.sendServiceUpdateNotifications(service, updateInfo) m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil return service, nil
@@ -254,7 +253,7 @@ type serviceUpdateInfo struct {
serviceEnabledChanged bool serviceEnabledChanged bool
} }
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) { func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -292,7 +291,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string
return &updateInfo, err return &updateInfo, err
} }
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error { func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err return err
} }
@@ -309,7 +308,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.
return nil return nil
} }
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) { func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" { service.Auth.PasswordAuth.Password == "" {
@@ -323,40 +322,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve
} }
} }
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) { func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
service.Meta = existingService.Meta service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey service.SessionPublicKey = existingService.SessionPublicKey
} }
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch { switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !service.Enabled && updateInfo.serviceEnabledChanged: case !s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
case service.Enabled && updateInfo.serviceEnabledChanged: case s.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
default: default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
} }
} }
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account. // validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error { func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets { for _, target := range targets {
switch target.TargetType { switch target.TargetType {
case reverseproxy.TargetTypePeer: case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
} }
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
} }
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain: case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
@@ -368,7 +367,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil return nil
} }
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil { if err != nil {
return status.NewPermissionValidationError(err) return status.NewPermissionValidationError(err)
@@ -377,10 +376,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
var service *reverseproxy.Service var s *service.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
return err return err
} }
@@ -395,9 +394,9 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err return err
} }
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -406,7 +405,7 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time. // SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
// Call this when receiving a gRPC notification that the certificate was issued. // Call this when receiving a gRPC notification that the certificate was issued.
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
@@ -424,7 +423,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser
} }
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) // SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil { if err != nil {
@@ -441,42 +440,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string
}) })
} }
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error { func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get service: %w", err) return fmt.Errorf("failed to get service: %w", err)
} }
err = m.replaceHostByLookup(ctx, accountID, service) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID) m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil return nil
} }
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get services: %w", err) return fmt.Errorf("failed to get services: %w", err)
} }
for _, service := range services { for _, s := range services {
err = m.replaceHostByLookup(ctx, accountID, service) err = m.replaceHostByLookup(ctx, accountID, s)
if err != nil { if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
} }
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
} }
return nil return nil
} }
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone) services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err) return nil, fmt.Errorf("failed to get services: %w", err)
@@ -492,7 +491,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se
return services, nil return services, nil
} }
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err) return nil, fmt.Errorf("failed to get service: %w", err)
@@ -506,7 +505,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s
return service, nil return service, nil
} }
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err) return nil, fmt.Errorf("failed to get services: %w", err)
@@ -522,7 +521,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string)
return services, nil return services, nil
} }
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {

View File

@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -20,13 +20,13 @@ func TestInitializeServiceForCreate(t *testing.T) {
accountID := "test-account" accountID := "test-account"
t.Run("successful initialization without cluster deriver", func(t *testing.T) { t.Run("successful initialization without cluster deriver", func(t *testing.T) {
mgr := &managerImpl{ mgr := &Manager{
clusterDeriver: nil, clusterDeriver: nil,
} }
service := &reverseproxy.Service{ service := &rpservice.Service{
Domain: "example.com", Domain: "example.com",
Auth: reverseproxy.AuthConfig{}, Auth: rpservice.AuthConfig{},
} }
err := mgr.initializeServiceForCreate(ctx, accountID, service) err := mgr.initializeServiceForCreate(ctx, accountID, service)
@@ -40,12 +40,12 @@ func TestInitializeServiceForCreate(t *testing.T) {
}) })
t.Run("verifies session keys are different", func(t *testing.T) { t.Run("verifies session keys are different", func(t *testing.T) {
mgr := &managerImpl{ mgr := &Manager{
clusterDeriver: nil, clusterDeriver: nil,
} }
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}} service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}} service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1) err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2) err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
@@ -87,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
}, },
expectedError: true, expectedError: true,
errorType: status.AlreadyExists, errorType: status.AlreadyExists,
@@ -99,7 +99,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
}, },
expectedError: false, expectedError: false,
}, },
@@ -110,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) { setupMock: func(ms *store.MockStore) {
ms.EXPECT(). ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com"). GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
}, },
expectedError: true, expectedError: true,
errorType: status.AlreadyExists, errorType: status.AlreadyExists,
@@ -136,7 +136,7 @@ func TestCheckDomainAvailable(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
tt.setupMock(mockStore) tt.setupMock(mockStore)
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
if tt.expectedError { if tt.expectedError {
@@ -166,7 +166,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, ""). GetServiceByDomain(ctx, accountID, "").
Return(nil, status.Errorf(status.NotFound, "not found")) Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
assert.NoError(t, err) assert.NoError(t, err)
@@ -179,9 +179,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT(). mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com"). GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil) Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
assert.Error(t, err) assert.Error(t, err)
@@ -199,7 +199,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
GetServiceByDomain(ctx, accountID, "nil.com"). GetServiceByDomain(ctx, accountID, "nil.com").
Return(nil, nil) Return(nil, nil)
mgr := &managerImpl{} mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
assert.NoError(t, err) assert.NoError(t, err)
@@ -215,10 +215,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{ service := &rpservice.Service{
ID: "service-123", ID: "service-123",
Domain: "new.com", Domain: "new.com",
Targets: []*reverseproxy.Target{}, Targets: []*rpservice.Target{},
} }
// Mock ExecuteInTransaction to execute the function immediately // Mock ExecuteInTransaction to execute the function immediately
@@ -237,7 +237,7 @@ func TestPersistNewService(t *testing.T) {
return fn(txMock) return fn(txMock)
}) })
mgr := &managerImpl{store: mockStore} mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service) err := mgr.persistNewService(ctx, accountID, service)
assert.NoError(t, err) assert.NoError(t, err)
@@ -248,10 +248,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl) mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{ service := &rpservice.Service{
ID: "service-123", ID: "service-123",
Domain: "existing.com", Domain: "existing.com",
Targets: []*reverseproxy.Target{}, Targets: []*rpservice.Target{},
} }
mockStore.EXPECT(). mockStore.EXPECT().
@@ -260,12 +260,12 @@ func TestPersistNewService(t *testing.T) {
txMock := store.NewMockStore(ctrl) txMock := store.NewMockStore(ctrl)
txMock.EXPECT(). txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com"). GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil) Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock) return fn(txMock)
}) })
mgr := &managerImpl{store: mockStore} mgr := &Manager{store: mockStore}
err := mgr.persistNewService(ctx, accountID, service) err := mgr.persistNewService(ctx, accountID, service)
require.Error(t, err) require.Error(t, err)
@@ -275,21 +275,21 @@ func TestPersistNewService(t *testing.T) {
}) })
} }
func TestPreserveExistingAuthSecrets(t *testing.T) { func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &managerImpl{} mgr := &Manager{}
t.Run("preserve password when empty", func(t *testing.T) { t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "hashed-password", Password: "hashed-password",
}, },
}, },
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "", Password: "",
}, },
@@ -302,18 +302,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}) })
t.Run("preserve pin when empty", func(t *testing.T) { t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{ PinAuth: &rpservice.PINAuthConfig{
Enabled: true, Enabled: true,
Pin: "hashed-pin", Pin: "hashed-pin",
}, },
}, },
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{ PinAuth: &rpservice.PINAuthConfig{
Enabled: true, Enabled: true,
Pin: "", Pin: "",
}, },
@@ -326,18 +326,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
}) })
t.Run("do not preserve when password is provided", func(t *testing.T) { t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "old-password", Password: "old-password",
}, },
}, },
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Auth: reverseproxy.AuthConfig{ Auth: rpservice.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{ PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true, Enabled: true,
Password: "new-password", Password: "new-password",
}, },
@@ -352,10 +352,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
} }
func TestPreserveServiceMetadata(t *testing.T) { func TestPreserveServiceMetadata(t *testing.T) {
mgr := &managerImpl{} mgr := &Manager{}
existing := &reverseproxy.Service{ existing := &rpservice.Service{
Meta: reverseproxy.ServiceMeta{ Meta: rpservice.ServiceMeta{
CertificateIssuedAt: time.Now(), CertificateIssuedAt: time.Now(),
Status: "active", Status: "active",
}, },
@@ -363,7 +363,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
SessionPublicKey: "public-key", SessionPublicKey: "public-key",
} }
updated := &reverseproxy.Service{ updated := &rpservice.Service{
Domain: "updated.com", Domain: "updated.com",
} }

View File

@@ -1,4 +1,4 @@
package reverseproxy package service
import ( import (
"errors" "errors"
@@ -26,15 +26,15 @@ const (
Delete Operation = "delete" Delete Operation = "delete"
) )
type ProxyStatus string type Status string
const ( const (
StatusPending ProxyStatus = "pending" StatusPending Status = "pending"
StatusActive ProxyStatus = "active" StatusActive Status = "active"
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created" StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending ProxyStatus = "certificate_pending" StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed ProxyStatus = "certificate_failed" StatusCertificateFailed Status = "certificate_failed"
StatusError ProxyStatus = "error" StatusError Status = "error"
TargetTypePeer = "peer" TargetTypePeer = "peer"
TargetTypeHost = "host" TargetTypeHost = "host"

View File

@@ -1,4 +1,4 @@
package reverseproxy package service
import ( import (
"errors" "errors"

View File

@@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler { func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler { return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies)
if err != nil { if err != nil {
log.Fatalf("failed to create API handler: %v", err) log.Fatalf("failed to create API handler: %v", err)
} }
@@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if s.Config.HttpConfig.LetsEncryptDomain != "" { if s.Config.HttpConfig.LetsEncryptDomain != "" {
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
if err != nil { if err != nil {
log.Fatalf("failed to create certificate manager: %v", err) log.Fatalf("failed to create certificate service: %v", err)
} }
transportCredentials := credentials.NewTLS(certManager.TLSConfig()) transportCredentials := credentials.NewTLS(certManager.TLSConfig())
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
@@ -163,9 +163,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetProxyManager(s.ReverseProxyManager()) proxyService.SetProxyManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController())
}) })
return proxyService return proxyService
}) })
@@ -188,7 +189,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig {
func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore {
return Create(s, func() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore {
tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100)
if err != nil {
log.Fatalf("failed to create proxy token store: %v", err)
}
log.Info("One-time token store initialized for proxy authentication") log.Info("One-time token store initialized for proxy authentication")
return tokenStore return tokenStore
}) })

View File

@@ -6,6 +6,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -106,6 +108,12 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
}) })
} }
func (s *BaseServer) ServiceProxyController() service.ProxyController {
return Create(s, func() service.ProxyController {
return nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer())
})
}
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
return Create(s, func() *server.AccountRequestBuffer { return Create(s, func() *server.AccountRequestBuffer {
return server.NewAccountRequestBuffer(context.Background(), s.Store()) return server.NewAccountRequestBuffer(context.Background(), s.Store())

View File

@@ -8,9 +8,11 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager { return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
if err != nil { if err != nil {
log.Fatalf("failed to create account manager: %v", err) log.Fatalf("failed to create account service: %v", err)
} }
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
accountManager.SetServiceManager(s.ReverseProxyManager()) accountManager.SetServiceManager(s.ServiceManager())
}) })
return accountManager return accountManager
@@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager {
return Create(s, func() idp.Manager { return Create(s, func() idp.Manager {
var idpManager idp.Manager var idpManager idp.Manager
var err error var err error
// Use embedded IdP manager if embedded Dex is configured and enabled. // Use embedded IdP service if embedded Dex is configured and enabled.
// Legacy IdpManager won't be used anymore even if configured. // Legacy IdpManager won't be used anymore even if configured.
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
if err != nil { if err != nil {
log.Fatalf("failed to create embedded IDP manager: %v", err) log.Fatalf("failed to create embedded IDP service: %v", err)
} }
return idpManager return idpManager
} }
// Fall back to external IdP manager // Fall back to external IdP service
if s.Config.IdpManagerConfig != nil { if s.Config.IdpManagerConfig != nil {
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
if err != nil { if err != nil {
log.Fatalf("failed to create IDP manager: %v", err) log.Fatalf("failed to create IDP service: %v", err)
} }
} }
return idpManager return idpManager
}) })
} }
// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil // OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled { if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
return nil return nil
@@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager {
func (s *BaseServer) ResourcesManager() resources.Manager { func (s *BaseServer) ResourcesManager() resources.Manager {
return Create(s, func() resources.Manager { return Create(s, func() resources.Manager {
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager()) return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
}) })
} }
@@ -190,15 +192,21 @@ func (s *BaseServer) RecordsManager() records.Manager {
}) })
} }
func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { func (s *BaseServer) ServiceManager() service.Manager {
return Create(s, func() reverseproxy.Manager { return Create(s, func() service.Manager {
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
})
}
func (s *BaseServer) ProxyManager() proxy.Manager {
return Create(s, func() proxy.Manager {
return proxymanager.NewManager(s.Store())
}) })
} }
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
return Create(s, func() *manager.Manager { return Create(s, func() *manager.Manager {
m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager()) m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager())
return &m return &m
}) })
} }

View File

@@ -1,202 +0,0 @@
package grpc
import (
"fmt"
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache cache.DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &cache.DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func TestBuildJWTConfig_Audiences(t *testing.T) {
tests := []struct {
name string
authAudience string
cliAuthAudience string
expectedAudiences []string
expectedAudience string
}{
{
name: "only_auth_audience",
authAudience: "dashboard-aud",
cliAuthAudience: "",
expectedAudiences: []string{"dashboard-aud"},
expectedAudience: "dashboard-aud",
},
{
name: "both_audiences_different",
authAudience: "dashboard-aud",
cliAuthAudience: "cli-aud",
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
expectedAudience: "cli-aud",
},
{
name: "both_audiences_same",
authAudience: "same-aud",
cliAuthAudience: "same-aud",
expectedAudiences: []string{"same-aud"},
expectedAudience: "same-aud",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &nbconfig.HttpServerConfig{
AuthIssuer: "https://issuer.example.com",
AuthAudience: tc.authAudience,
CLIAuthAudience: tc.cliAuthAudience,
}
result := buildJWTConfig(config, nil)
assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
//nolint:staticcheck // SA1019: Testing backwards compatibility - Audience field must still be populated
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}
}

View File

@@ -1,276 +0,0 @@
package grpc
import (
"hash/fnv"
"math"
"math/rand"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func testAdvancedCfg() *lfConfig {
return &lfConfig{
reconnThreshold: 50 * time.Millisecond,
baseBlockDuration: 100 * time.Millisecond,
reconnLimitForBan: 3,
metaChangeLimit: 2,
}
}
type LoginFilterTestSuite struct {
suite.Suite
filter *loginFilter
}
func (s *LoginFilterTestSuite) SetupTest() {
s.filter = newLoginFilterWithCfg(testAdvancedCfg())
}
func TestLoginFilterTestSuite(t *testing.T) {
suite.Run(t, new(LoginFilterTestSuite))
}
func (s *LoginFilterTestSuite) TestFirstLoginIsAlwaysAllowed() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(1, s.filter.logged[pubKey].sessionCounter)
}
func (s *LoginFilterTestSuite) TestFlappingSameHashTriggersBan() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.False(s.filter.allowLogin(pubKey, meta))
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanDurationIncreasesExponentially() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
limit := s.filter.cfg.reconnLimitForBan
baseBan := s.filter.cfg.baseBlockDuration
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.Require().Contains(s.filter.logged, pubKey)
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(1, s.filter.logged[pubKey].banLevel)
firstBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
s.InDelta(baseBan, firstBanDuration, float64(time.Millisecond))
s.filter.logged[pubKey].banExpiresAt = time.Now().Add(-time.Second)
s.filter.logged[pubKey].isBanned = false
for i := 0; i <= limit; i++ {
s.filter.addLogin(pubKey, meta)
}
s.True(s.filter.logged[pubKey].isBanned)
s.Equal(2, s.filter.logged[pubKey].banLevel)
secondBanDuration := s.filter.logged[pubKey].banExpiresAt.Sub(s.filter.logged[pubKey].lastSeen)
// nolint
expectedSecondDuration := time.Duration(float64(baseBan) * math.Pow(2, 1))
s.InDelta(expectedSecondDuration, secondBanDuration, float64(time.Millisecond))
}
func (s *LoginFilterTestSuite) TestPeerIsAllowedAfterBanExpires() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
isBanned: true,
banExpiresAt: time.Now().Add(-(s.filter.cfg.baseBlockDuration + time.Second)),
}
s.True(s.filter.allowLogin(pubKey, meta))
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.False(s.filter.logged[pubKey].isBanned)
}
func (s *LoginFilterTestSuite) TestBanLevelResetsAfterGoodBehavior() {
pubKey := "PUB_KEY_A"
meta := uint64(1)
s.filter.logged[pubKey] = &peerState{
currentHash: meta,
banLevel: 3,
lastSeen: time.Now().Add(-3 * s.filter.cfg.baseBlockDuration),
}
s.filter.addLogin(pubKey, meta)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(0, s.filter.logged[pubKey].banLevel)
}
func (s *LoginFilterTestSuite) TestFlappingDifferentHashesTriggersBlock() {
pubKey := "PUB_KEY_A"
limit := s.filter.cfg.metaChangeLimit
for i := range limit {
s.filter.addLogin(pubKey, uint64(i+1))
}
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(limit, s.filter.logged[pubKey].metaChangeCounter)
isAllowed := s.filter.allowLogin(pubKey, uint64(limit+1))
s.False(isAllowed, "should block new meta hash after limit is reached")
}
func (s *LoginFilterTestSuite) TestMetaChangeIsAllowedAfterWindowResets() {
pubKey := "PUB_KEY_A"
meta1 := uint64(1)
meta2 := uint64(2)
meta3 := uint64(3)
s.filter.addLogin(pubKey, meta1)
s.filter.addLogin(pubKey, meta2)
s.Require().Contains(s.filter.logged, pubKey)
s.Equal(s.filter.cfg.metaChangeLimit, s.filter.logged[pubKey].metaChangeCounter)
s.False(s.filter.allowLogin(pubKey, meta3), "should be blocked inside window")
s.filter.logged[pubKey].metaChangeWindowStart = time.Now().Add(-(s.filter.cfg.reconnThreshold + time.Second))
s.True(s.filter.allowLogin(pubKey, meta3), "should be allowed after window expires")
s.filter.addLogin(pubKey, meta3)
s.Equal(1, s.filter.logged[pubKey].metaChangeCounter, "meta change counter should reset")
}
func BenchmarkHashingMethods(b *testing.B) {
meta := nbpeer.PeerSystemMeta{
WtVersion: "1.25.1",
OSVersion: "Ubuntu 22.04.3 LTS",
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
b.Run("BuilderString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
}
})
b.Run("FnvHashToString", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
}
})
b.Run("FnvHashToUint64 - used", func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
}
})
_ = resultString
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000
pubKeys := make([]string, numKeys)
for i := range numKeys {
pubKeys[i] = "PUB_KEY_" + strconv.Itoa(i)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
for pb.Next() {
key := pubKeys[r.Intn(numKeys)]
meta := r.Uint64()
if filter.allowLogin(key, meta) {
filter.addLogin(key, meta)
}
}
})
}

View File

@@ -1,28 +1,23 @@
package grpc package grpc
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/sha256"
"crypto/subtle" "crypto/subtle"
"encoding/base64" "encoding/base64"
"encoding/hex"
"encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/eko/gocache/lib/v4/cache"
"github.com/eko/gocache/lib/v4/store"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbcache "github.com/netbirdio/netbird/management/server/cache"
) )
// OneTimeTokenStore manages short-lived, single-use authentication tokens
// for proxy-to-management RPC authentication. Tokens are generated when
// a service is created and must be used exactly once by the proxy
// to authenticate a subsequent RPC call.
type OneTimeTokenStore struct {
tokens map[string]*tokenMetadata
mu sync.RWMutex
cleanup *time.Ticker
cleanupDone chan struct{}
}
// tokenMetadata stores information about a one-time token
type tokenMetadata struct { type tokenMetadata struct {
ServiceID string ServiceID string
AccountID string AccountID string
@@ -30,20 +25,24 @@ type tokenMetadata struct {
CreatedAt time.Time CreatedAt time.Time
} }
// NewOneTimeTokenStore creates a new token store with automatic cleanup // OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
// of expired tokens. The cleanupInterval determines how often expired // Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
// tokens are removed from memory. type OneTimeTokenStore struct {
func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { cache *cache.Cache[string]
store := &OneTimeTokenStore{ ctx context.Context
tokens: make(map[string]*tokenMetadata), }
cleanup: time.NewTicker(cleanupInterval),
cleanupDone: make(chan struct{}), // NewOneTimeTokenStore creates a token store with automatic backend selection
func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) {
cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn)
if err != nil {
return nil, fmt.Errorf("failed to create cache store: %w", err)
} }
// Start background cleanup goroutine return &OneTimeTokenStore{
go store.cleanupExpired() cache: cache.New[string](cacheStore),
ctx: ctx,
return store }, nil
} }
// GenerateToken creates a new cryptographically secure one-time token // GenerateToken creates a new cryptographically secure one-time token
@@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore {
// //
// Returns the generated token string or an error if random generation fails. // Returns the generated token string or an error if random generation fails.
func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) { func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) {
// Generate 32 bytes (256 bits) of cryptographically secure random data
randomBytes := make([]byte, 32) randomBytes := make([]byte, 32)
if _, err := rand.Read(randomBytes); err != nil { if _, err := rand.Read(randomBytes); err != nil {
return "", fmt.Errorf("failed to generate random token: %w", err) return "", fmt.Errorf("failed to generate random token: %w", err)
} }
// Encode as URL-safe base64 for easy transmission in gRPC
token := base64.URLEncoding.EncodeToString(randomBytes) token := base64.URLEncoding.EncodeToString(randomBytes)
hashedToken := hashToken(token)
s.mu.Lock() metadata := &tokenMetadata{
defer s.mu.Unlock()
s.tokens[token] = &tokenMetadata{
ServiceID: serviceID, ServiceID: serviceID,
AccountID: accountID, AccountID: accountID,
ExpiresAt: time.Now().Add(ttl), ExpiresAt: time.Now().Add(ttl),
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
metadataJSON, err := json.Marshal(metadata)
if err != nil {
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
}
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
return "", fmt.Errorf("failed to store token: %w", err)
}
log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)", log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)",
serviceID, accountID, ttl) serviceID, accountID, ttl)
@@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
// - Account ID doesn't match // - Account ID doesn't match
// - Reverse proxy ID doesn't match // - Reverse proxy ID doesn't match
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
s.mu.Lock() hashedToken := hashToken(token)
defer s.mu.Unlock()
metadata, exists := s.tokens[token] metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
if !exists { if err != nil {
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
serviceID, accountID)
return fmt.Errorf("invalid token") return fmt.Errorf("invalid token")
} }
// Check expiration metadata := &tokenMetadata{}
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
return fmt.Errorf("invalid token metadata")
}
if time.Now().After(metadata.ExpiresAt) { if time.Now().After(metadata.ExpiresAt) {
delete(s.tokens, token) log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)",
serviceID, accountID)
return fmt.Errorf("token expired") return fmt.Errorf("token expired")
} }
// Validate account ID using constant-time comparison (prevents timing attacks)
if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 { if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 {
log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID)
metadata.AccountID, accountID)
return fmt.Errorf("account ID mismatch") return fmt.Errorf("account ID mismatch")
} }
// Validate service ID using constant-time comparison
if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 { if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 {
log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID)
metadata.ServiceID, serviceID)
return fmt.Errorf("service ID mismatch") return fmt.Errorf("service ID mismatch")
} }
// Delete token immediately to enforce single-use if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
delete(s.tokens, token) log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
}
log.Infof("Token validated and consumed for proxy %s in account %s", log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID)
serviceID, accountID)
return nil return nil
} }
// cleanupExpired removes expired tokens in the background to prevent memory leaks func hashToken(token string) string {
func (s *OneTimeTokenStore) cleanupExpired() { hash := sha256.Sum256([]byte(token))
for { return hex.EncodeToString(hash[:])
select {
case <-s.cleanup.C:
s.mu.Lock()
now := time.Now()
removed := 0
for token, metadata := range s.tokens {
if now.After(metadata.ExpiresAt) {
delete(s.tokens, token)
removed++
}
}
if removed > 0 {
log.Debugf("Cleaned up %d expired one-time tokens", removed)
}
s.mu.Unlock()
case <-s.cleanupDone:
return
}
}
}
// Close stops the cleanup goroutine and releases resources
func (s *OneTimeTokenStore) Close() {
s.cleanup.Stop()
close(s.cleanupDone)
}
// GetTokenCount returns the current number of tokens in the store (for debugging/metrics)
func (s *OneTimeTokenStore) GetTokenCount() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.tokens)
} }

View File

@@ -24,8 +24,9 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
@@ -58,9 +59,6 @@ type ProxyServiceServer struct {
// Map of connected proxies: proxy_id -> proxy connection // Map of connected proxies: proxy_id -> proxy connection
connectedProxies sync.Map connectedProxies sync.Map
// Map of cluster address -> set of proxy IDs
clusterProxies sync.Map
// Channel for broadcasting reverse proxy updates to all proxies // Channel for broadcasting reverse proxy updates to all proxies
updatesChan chan *proto.ProxyMapping updatesChan chan *proto.ProxyMapping
@@ -68,7 +66,13 @@ type ProxyServiceServer struct {
accessLogManager accesslogs.Manager accessLogManager accesslogs.Manager
// Manager for reverse proxy operations // Manager for reverse proxy operations
reverseProxyManager reverseproxy.Manager serviceManager rpservice.Manager
// ProxyController for service updates and cluster management
proxyController rpservice.ProxyController
// Manager for proxy connections
proxyManager proxy.Manager
// Manager for peers // Manager for peers
peersManager peers.Manager peersManager peers.Manager
@@ -107,7 +111,7 @@ type proxyConnection struct {
} }
// NewProxyServiceServer creates a new proxy service server. // NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{ s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100), updatesChan: make(chan *proto.ProxyMapping, 100),
@@ -116,9 +120,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
tokenStore: tokenStore, tokenStore: tokenStore,
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
proxyManager: proxyMgr,
pkceCleanupCancel: cancel, pkceCleanupCancel: cancel,
} }
go s.cleanupPKCEVerifiers(ctx) go s.cleanupPKCEVerifiers(ctx)
go s.cleanupStaleProxies(ctx)
return s return s
} }
@@ -142,13 +148,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) {
} }
} }
// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes
func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil {
log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err)
}
}
}
}
// Close stops background goroutines. // Close stops background goroutines.
func (s *ProxyServiceServer) Close() { func (s *ProxyServiceServer) Close() {
s.pkceCleanupCancel() s.pkceCleanupCancel()
} }
func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) { func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) {
s.reverseProxyManager = manager s.serviceManager = manager
}
func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) {
s.proxyController = proxyController
} }
// GetMappingUpdate handles the control stream with proxy clients // GetMappingUpdate handles the control stream with proxy clients
@@ -183,7 +209,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
s.connectedProxies.Store(proxyID, conn) s.connectedProxies.Store(proxyID, conn)
s.addToCluster(conn.address, proxyID) if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
}
// Register proxy in database
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
}
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"proxy_id": proxyID, "proxy_id": proxyID,
"address": proxyAddress, "address": proxyAddress,
@@ -191,8 +225,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
"total_proxies": len(s.GetConnectedProxies()), "total_proxies": len(s.GetConnectedProxies()),
}).Info("Proxy registered in cluster") }).Info("Proxy registered in cluster")
defer func() { defer func() {
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
}
s.connectedProxies.Delete(proxyID) s.connectedProxies.Delete(proxyID)
s.removeFromCluster(conn.address, proxyID) if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
}
cancel() cancel()
log.Infof("Proxy %s disconnected", proxyID) log.Infof("Proxy %s disconnected", proxyID)
}() }()
@@ -204,6 +245,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
errChan := make(chan error, 2) errChan := make(chan error, 2)
go s.sender(conn, errChan) go s.sender(conn, errChan)
// Start heartbeat goroutine
go s.heartbeat(connCtx, proxyID)
select { select {
case err := <-errChan: case err := <-errChan:
return fmt.Errorf("send update to proxy %s: %w", proxyID, err) return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
@@ -212,10 +256,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
} }
} }
// heartbeat updates the proxy's last_seen timestamp every minute
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) {
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil {
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
}
case <-ctx.Done():
return
}
}
}
// sendSnapshot sends the initial snapshot of services to the connecting proxy. // sendSnapshot sends the initial snapshot of services to the connecting proxy.
// Only services matching the proxy's cluster address are sent. // Only services matching the proxy's cluster address are sent.
func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error {
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return fmt.Errorf("get services from store: %w", err) return fmt.Errorf("get services from store: %w", err)
} }
@@ -224,7 +285,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
return fmt.Errorf("proxy address is invalid") return fmt.Errorf("proxy address is invalid")
} }
var filtered []*reverseproxy.Service var filtered []*rpservice.Service
for _, service := range services { for _, service := range services {
if !service.Enabled { if !service.Enabled {
continue continue
@@ -259,7 +320,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{ Mapping: []*proto.ProxyMapping{
service.ToProtoMapping( service.ToProtoMapping(
reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy. rpservice.Create, // Initial snapshot, all records are "new" for the proxy.
token, token,
s.GetOIDCValidationConfig(), s.GetOIDCValidationConfig(),
), ),
@@ -393,61 +454,43 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
return urls return urls
} }
// addToCluster registers a proxy in a cluster.
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
}
// removeFromCluster removes a proxy from a cluster.
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
if clusterAddr == "" {
return
}
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
proxySet.(*sync.Map).Delete(proxyID)
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
}
}
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
// For create/update operations a unique one-time auth token is generated per // For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management. // proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
if clusterAddr == "" { if clusterAddr == "" {
s.SendServiceUpdate(update) s.SendServiceUpdate(update)
return return
} }
proxySet, ok := s.clusterProxies.Load(clusterAddr) if s.proxyController == nil {
if !ok { log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
log.Debugf("No proxies connected for cluster %s", clusterAddr) return
}
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
if len(proxyIDs) == 0 {
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
return return
} }
log.Debugf("Sending service update to cluster %s", clusterAddr) log.Debugf("Sending service update to cluster %s", clusterAddr)
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { for _, proxyID := range proxyIDs {
proxyID := key.(string)
if connVal, ok := s.connectedProxies.Load(proxyID); ok { if connVal, ok := s.connectedProxies.Load(proxyID); ok {
conn := connVal.(*proxyConnection) conn := connVal.(*proxyConnection)
msg := s.perProxyMessage(update, proxyID) msg := s.perProxyMessage(update, proxyID)
if msg == nil { if msg == nil {
return true continue
} }
select { select {
case conn.sendChan <- msg: case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
default: default:
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
} }
} }
return true }
})
} }
// perProxyMessage returns a copy of update with a fresh one-time token for // perProxyMessage returns a copy of update with a fresh one-time token for
@@ -486,35 +529,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
} }
} }
// GetAvailableClusters returns information about all connected proxy clusters.
func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo {
clusterCounts := make(map[string]int)
s.clusterProxies.Range(func(key, value interface{}) bool {
clusterAddr := key.(string)
proxySet := value.(*sync.Map)
count := 0
proxySet.Range(func(_, _ interface{}) bool {
count++
return true
})
if count > 0 {
clusterCounts[clusterAddr] = count
}
return true
})
clusters := make([]ClusterInfo, 0, len(clusterCounts))
for addr, count := range clusterCounts {
clusters = append(clusters, ClusterInfo{
Address: addr,
ConnectedProxies: count,
})
}
return clusters
}
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get service from store: %v", err) log.WithContext(ctx).Debugf("failed to get service from store: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err)
@@ -533,7 +549,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen
}, nil }, nil
} }
func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) {
switch v := req.GetRequest().(type) { switch v := req.GetRequest().(type) {
case *proto.AuthenticateRequest_Pin: case *proto.AuthenticateRequest_Pin:
return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth)
@@ -544,7 +560,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto
} }
} }
func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled { if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID) log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID)
return false, "", "" return false, "", ""
@@ -558,7 +574,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri
return true, "pin-user", proxyauth.MethodPIN return true, "pin-user", proxyauth.MethodPIN
} }
func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) { func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) {
if auth == nil || !auth.Enabled { if auth == nil || !auth.Enabled {
log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID) log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID)
return false, "", "" return false, "", ""
@@ -580,7 +596,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err
} }
} }
func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) {
if !authenticated || service.SessionPrivateKey == "" { if !authenticated || service.SessionPrivateKey == "" {
return "", nil return "", nil
} }
@@ -620,7 +636,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
} }
if certificateIssued { if certificateIssued {
if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp")
return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err) return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err)
} }
@@ -632,7 +648,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
internalStatus := protoStatusToInternal(protoStatus) internalStatus := protoStatusToInternal(protoStatus)
if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil {
log.WithContext(ctx).WithError(err).Error("failed to update service status") log.WithContext(ctx).WithError(err).Error("failed to update service status")
return nil, status.Errorf(codes.Internal, "update service status: %v", err) return nil, status.Errorf(codes.Internal, "update service status: %v", err)
} }
@@ -647,22 +663,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se
} }
// protoStatusToInternal maps proto status to internal status // protoStatusToInternal maps proto status to internal status
func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus { func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status {
switch protoStatus { switch protoStatus {
case proto.ProxyStatus_PROXY_STATUS_PENDING: case proto.ProxyStatus_PROXY_STATUS_PENDING:
return reverseproxy.StatusPending return rpservice.StatusPending
case proto.ProxyStatus_PROXY_STATUS_ACTIVE: case proto.ProxyStatus_PROXY_STATUS_ACTIVE:
return reverseproxy.StatusActive return rpservice.StatusActive
case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED: case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED:
return reverseproxy.StatusTunnelNotCreated return rpservice.StatusTunnelNotCreated
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING: case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING:
return reverseproxy.StatusCertificatePending return rpservice.StatusCertificatePending
case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED: case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED:
return reverseproxy.StatusCertificateFailed return rpservice.StatusCertificateFailed
case proto.ProxyStatus_PROXY_STATUS_ERROR: case proto.ProxyStatus_PROXY_STATUS_ERROR:
return reverseproxy.StatusError return rpservice.StatusError
default: default:
return reverseproxy.StatusError return rpservice.StatusError
} }
} }
@@ -727,7 +743,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU
return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err)
} }
// Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection.
services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId()) services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId())
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account services: %v", err) log.WithContext(ctx).Errorf("failed to get account services: %v", err)
return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err)
@@ -790,8 +806,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig {
// GetOIDCValidationConfig returns the OIDC configuration for token validation // GetOIDCValidationConfig returns the OIDC configuration for token validation
// in the format needed by ToProtoMapping. // in the format needed by ToProtoMapping.
func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig { func (s *ProxyServiceServer) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
return reverseproxy.OIDCValidationConfig{ return rpservice.OIDCValidationConfig{
Issuer: s.oidcConfig.Issuer, Issuer: s.oidcConfig.Issuer,
Audiences: []string{s.oidcConfig.Audience}, Audiences: []string{s.oidcConfig.Audience},
KeysLocation: s.oidcConfig.KeysLocation, KeysLocation: s.oidcConfig.KeysLocation,
@@ -850,12 +866,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL
// GenerateSessionToken creates a signed session JWT for the given domain and user. // GenerateSessionToken creates a signed session JWT for the given domain and user.
func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) {
// Find the service by domain to get its signing key // Find the service by domain to get its signing key
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return "", fmt.Errorf("get services: %w", err) return "", fmt.Errorf("get services: %w", err)
} }
var service *reverseproxy.Service var service *rpservice.Service
for _, svc := range services { for _, svc := range services {
if svc.Domain == domain { if svc.Domain == domain {
service = svc service = svc
@@ -921,8 +937,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain
return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain) return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain)
} }
func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID) services, err := s.serviceManager.GetAccountServices(ctx, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account services: %w", err) return nil, fmt.Errorf("get account services: %w", err)
} }
@@ -1043,8 +1059,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val
}, nil }, nil
} }
func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) { func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) {
services, err := s.reverseProxyManager.GetGlobalServices(ctx) services, err := s.serviceManager.GetGlobalServices(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("get services: %w", err) return nil, fmt.Errorf("get services: %w", err)
} }
@@ -1058,7 +1074,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri
return nil, fmt.Errorf("service not found for domain: %s", domain) return nil, fmt.Errorf("service not found for domain: %s", domain)
} }
func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error { func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error {
if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled { if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled {
return nil return nil
} }

View File

@@ -1,98 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
)
func TestAuthFailureLimiter_NotLimitedInitially(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
assert.False(t, l.isLimited("192.168.1.1"), "new IP should not be rate limited")
}
func TestAuthFailureLimiter_LimitedAfterBurst(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
ip := "192.168.1.1"
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
assert.True(t, l.isLimited(ip), "IP should be limited after exhausting burst")
}
func TestAuthFailureLimiter_DifferentIPsIndependent(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure("192.168.1.1")
}
assert.True(t, l.isLimited("192.168.1.1"))
assert.False(t, l.isLimited("192.168.1.2"), "different IP should not be affected")
}
func TestAuthFailureLimiter_RecoveryOverTime(t *testing.T) {
l := newAuthFailureLimiterWithRate(rate.Limit(100)) // 100 tokens/sec for fast recovery
defer l.stop()
ip := "10.0.0.1"
// Exhaust burst
for i := 0; i < proxyAuthFailureBurst; i++ {
l.recordFailure(ip)
}
require.True(t, l.isLimited(ip))
// Wait for token replenishment
time.Sleep(50 * time.Millisecond)
assert.False(t, l.isLimited(ip), "should recover after tokens replenish")
}
func TestAuthFailureLimiter_Cleanup(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.mu.Lock()
require.Len(t, l.limiters, 1)
// Backdate the entry so it looks stale
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Empty(t, l.limiters, "stale entries should be cleaned up")
l.mu.Unlock()
}
func TestAuthFailureLimiter_CleanupKeepsFresh(t *testing.T) {
l := newAuthFailureLimiter()
defer l.stop()
l.recordFailure("10.0.0.1")
l.recordFailure("10.0.0.2")
l.mu.Lock()
// Only backdate one entry
l.limiters["10.0.0.1"].lastAccess = time.Now().Add(-proxyAuthLimiterTTL - time.Minute)
l.mu.Unlock()
l.cleanup()
l.mu.Lock()
assert.Len(t, l.limiters, 1, "only stale entries should be removed")
assert.Contains(t, l.limiters, "10.0.0.2")
l.mu.Unlock()
}

View File

@@ -1,381 +0,0 @@
package grpc
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/server/types"
)
type mockReverseProxyManager struct {
proxiesByAccount map[string][]*reverseproxy.Service
err error
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err
}
return m.proxiesByAccount[accountID], nil
}
func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return []*reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *mockReverseProxyManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, reverseProxyID string) error {
return nil
}
func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type mockUsersManager struct {
users map[string]*types.User
err error
}
func (m *mockUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
if m.err != nil {
return nil, m.err
}
user, ok := m.users[userID]
if !ok {
return nil, errors.New("user not found")
}
return user, nil
}
func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct {
name string
domain string
userID string
proxiesByAccount map[string][]*reverseproxy.Service
users map[string]*types.User
proxyErr error
userErr error
expectErr bool
expectErrMsg string
}{
{
name: "user not found",
domain: "app.example.com",
userID: "unknown-user",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
users: map[string]*types.User{},
expectErr: true,
expectErrMsg: "user not found",
},
{
name: "proxy not found in user's account",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "proxy exists in different account - not accessible",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account2": {{Domain: "app.example.com", AccountID: "account2"}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "service not found",
},
{
name: "no bearer auth configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth disabled - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "bearer auth enabled but no groups configured - same account allows access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
{
name: "user not in allowed groups",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group3", "group4"}},
},
expectErr: true,
expectErrMsg: "not in allowed groups",
},
{
name: "user in one of the allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group2", "group3"}},
},
expectErr: false,
},
{
name: "user in all allowed groups - allow access",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{
Domain: "app.example.com",
AccountID: "account1",
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"group1", "group2"},
},
},
}},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1", AutoGroups: []string{"group1", "group2", "group3"}},
},
expectErr: false,
},
{
name: "proxy manager error",
domain: "app.example.com",
userID: "user1",
proxiesByAccount: nil,
proxyErr: errors.New("database error"),
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: true,
expectErrMsg: "get account services",
},
{
name: "multiple proxies in account - finds correct one",
domain: "app2.example.com",
userID: "user1",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "app1.example.com", AccountID: "account1"},
{Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}},
{Domain: "app3.example.com", AccountID: "account1"},
},
},
users: map[string]*types.User{
"user1": {Id: "user1", AccountID: "account1"},
},
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.proxyErr,
},
usersManager: &mockUsersManager{
users: tt.users,
err: tt.userErr,
},
}
err := server.ValidateUserGroupAccess(context.Background(), tt.domain, tt.userID)
if tt.expectErr {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.expectErrMsg)
} else {
require.NoError(t, err)
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct {
name string
accountID string
domain string
proxiesByAccount map[string][]*reverseproxy.Service
err error
expectProxy bool
expectErr bool
}{
{
name: "proxy found",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {
{Domain: "other.example.com", AccountID: "account1"},
{Domain: "app.example.com", AccountID: "account1"},
},
},
expectProxy: true,
expectErr: false,
},
{
name: "proxy not found in account",
accountID: "account1",
domain: "unknown.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{
"account1": {{Domain: "app.example.com", AccountID: "account1"}},
},
expectProxy: false,
expectErr: true,
},
{
name: "empty proxy list for account",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: map[string][]*reverseproxy.Service{},
expectProxy: false,
expectErr: true,
},
{
name: "manager error",
accountID: "account1",
domain: "app.example.com",
proxiesByAccount: nil,
err: errors.New("database error"),
expectProxy: false,
expectErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := &ProxyServiceServer{
reverseProxyManager: &mockReverseProxyManager{
proxiesByAccount: tt.proxiesByAccount,
err: tt.err,
},
}
proxy, err := server.getAccountServiceByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectErr {
require.Error(t, err)
assert.Nil(t, proxy)
} else {
require.NoError(t, err)
require.NotNil(t, proxy)
assert.Equal(t, tt.domain, proxy.Domain)
}
})
}
}

View File

@@ -1,232 +0,0 @@
package grpc
import (
"crypto/rand"
"encoding/base64"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
sendChan: ch,
}
s.connectedProxies.Store(proxyID, conn)
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
proxySet.(*sync.Map).Store(proxyID, struct{}{})
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
select {
case msg := <-ch:
return msg
case <-time.After(time.Second):
return nil
}
}
func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
Path: []*proto.PathMapping{
{Path: "/", Target: "http://10.0.0.1:8080/"},
},
}
s.SendServiceUpdateToCluster(update, cluster)
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
// All tokens must be unique
tokenSet := make(map[string]struct{})
for i, tok := range tokens {
_, exists := tokenSet[tok]
assert.False(t, exists, "proxy %d got duplicate token", i)
tokenSet[tok] = struct{}{}
}
// Each token must be independently consumable
for i, tok := range tokens {
err := tokenStore.ValidateAndConsume(tok, "account-1", "service-1")
assert.NoError(t, err, "proxy %d token should validate successfully", i)
}
}
func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdateToCluster(update, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
assert.Empty(t, msg2.AuthToken)
// No tokens should have been created
assert.Equal(t, 0, tokenStore.GetTokenCount())
}
func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
tokenStore := NewOneTimeTokenStore(time.Hour)
defer tokenStore.Close()
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
}
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)
assert.NotEqual(t, msg1.AuthToken, msg2.AuthToken, "tokens must be unique per proxy")
// Both tokens should validate
assert.NoError(t, tokenStore.ValidateAndConsume(msg1.AuthToken, "account-1", "service-1"))
assert.NoError(t, tokenStore.ValidateAndConsume(msg2.AuthToken, "account-1", "service-1"))
}
// generateState creates a state using the same format as GetOIDCURL.
func generateState(s *ProxyServiceServer, redirectURL string) string {
nonce := make([]byte, 16)
_, _ = rand.Read(nonce)
nonceB64 := base64.URLEncoding.EncodeToString(nonce)
payload := redirectURL + "|" + nonceB64
hmacSum := s.generateHMAC(payload)
return base64.URLEncoding.EncodeToString([]byte(redirectURL)) + "|" + nonceB64 + "|" + hmacSum
}
func TestOAuthState_NeverTheSame(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
redirectURL := "https://app.example.com/callback"
// Generate 100 states for the same redirect URL
states := make(map[string]bool)
for i := 0; i < 100; i++ {
state := generateState(s, redirectURL)
// State must have 3 parts: base64(url)|nonce|hmac
parts := strings.Split(state, "|")
require.Equal(t, 3, len(parts), "state must have 3 parts")
// State must be unique
require.False(t, states[state], "state %d is a duplicate", i)
states[state] = true
}
}
func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Old format had only 2 parts: base64(url)|hmac
s.pkceVerifiers.Store("base64url|hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("base64url|hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state format")
}
func TestValidateState_RejectsInvalidHMAC(t *testing.T) {
s := &ProxyServiceServer{
oidcConfig: ProxyOIDCConfig{
HMACKey: []byte("test-hmac-key"),
},
}
// Store with tampered HMAC
s.pkceVerifiers.Store("dGVzdA==|nonce|wrong-hmac", pkceEntry{verifier: "test", createdAt: time.Now()})
_, _, err := s.ValidateState("dGVzdA==|nonce|wrong-hmac")
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid state signature")
}

View File

@@ -1,108 +0,0 @@
package grpc
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
testingServerKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate server wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testingClientKey, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Errorf("unable to generate client wg key for testing GetDeviceAuthorizationFlow, error: %v", err)
}
testCases := []struct {
name string
inputFlow *config.DeviceAuthorizationFlow
expectedFlow *mgmtProto.DeviceAuthorizationFlow
expectedErrFunc require.ErrorAssertionFunc
expectedErrMSG string
expectedComparisonFunc require.ComparisonAssertionFunc
expectedComparisonMSG string
}{
{
name: "Testing No Device Flow Config",
inputFlow: nil,
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Invalid Device Flow Provider Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "NoNe",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.Error,
expectedErrMSG: "should return error",
},
{
name: "Testing Full Device Flow Config",
inputFlow: &config.DeviceAuthorizationFlow{
Provider: "hosted",
ProviderConfig: config.ProviderConfig{
ClientID: "test",
},
},
expectedFlow: &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "test",
},
},
expectedErrFunc: require.NoError,
expectedErrMSG: "should not return error",
expectedComparisonFunc: require.Equal,
expectedComparisonMSG: "should match",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
mgmtServer := &Server{
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
config: &config.Config{
DeviceAuthorizationFlow: testCase.inputFlow,
},
}
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
key, err := mgmtServer.secretsManager.GetWGKey()
require.NoError(t, err, "should be able to get server key")
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
require.NoError(t, err, "should be able to encrypt message")
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
context.TODO(),
&mgmtProto.EncryptedMessage{
WgPubKey: testingClientKey.PublicKey().String(),
Body: encryptedMSG,
},
)
testCase.expectedErrFunc(t, err, testCase.expectedErrMSG)
if testCase.expectedComparisonFunc != nil {
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
require.NoError(t, err, "should be able to decrypt")
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
testCase.expectedComparisonFunc(t, testCase.expectedFlow.ProviderConfig.ClientID, flowInfoResp.ProviderConfig.ClientID, testCase.expectedComparisonMSG)
}
})
}
}

View File

@@ -1,250 +0,0 @@
package grpc
import (
"context"
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"hash"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
var TurnTestHost = &config.Host{
Proto: config.UDP,
URI: "turn:turn.netbird.io:77777",
Username: "username",
Password: "",
}
func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
if turnCredentials.Payload == "" {
t.Errorf("expected generated TURN username not to be empty, got empty")
}
if turnCredentials.Signature == "" {
t.Errorf("expected generated TURN password not to be empty, got empty")
}
validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
relayCredentials, err := tested.GenerateRelayToken()
require.NoError(t, err)
if relayCredentials.Payload == "" {
t.Errorf("expected generated relay payload not to be empty, got empty")
}
if relayCredentials.Signature == "" {
t.Errorf("expected generated relay signature not to be empty, got empty")
}
hashedSecret := sha256.Sum256([]byte(secret))
validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
}
func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
ttl := util.Duration{Duration: 2 * time.Second}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
updateChannel := peersManager.CreateChannel(context.Background(), peer)
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tested.SetupRefresh(ctx, "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in the relay cancel map, got not present")
}
var updates []*network_map.UpdateMessage
loop:
for timeout := time.After(5 * time.Second); ; {
select {
case update := <-updateChannel:
updates = append(updates, update)
case <-timeout:
break loop
}
if len(updates) >= 2 {
break loop
}
}
if len(updates) < 2 {
t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
}
var turnUpdates, relayUpdates int
var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
for _, update := range updates {
if turns := update.Update.GetNetbirdConfig().GetTurns(); len(turns) > 0 {
turnUpdates++
if turnUpdates == 1 {
firstTurnUpdate = turns[0]
} else {
secondTurnUpdate = turns[0]
}
}
if relay := update.Update.GetNetbirdConfig().GetRelay(); relay != nil {
// avoid updating on turn updates since they also send relay credentials
if update.Update.GetNetbirdConfig().GetTurns() == nil {
relayUpdates++
if relayUpdates == 1 {
firstRelayUpdate = relay
} else {
secondRelayUpdate = relay
}
}
}
}
if turnUpdates < 1 {
t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
}
if relayUpdates < 1 {
t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
}
if firstTurnUpdate != nil && secondTurnUpdate != nil {
if firstTurnUpdate.Password == secondTurnUpdate.Password {
t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
}
}
if firstRelayUpdate != nil && secondRelayUpdate != nil {
if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
}
}
}
func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ttl := util.Duration{Duration: time.Hour}
secret := "some_secret"
peersManager := update_channel.NewPeersUpdateManager(nil)
peer := "some_peer"
rc := &config.Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: ttl,
Secret: secret,
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*config.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager, groupsManager)
require.NoError(t, err)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in turn cancel map, got not present")
}
if _, ok := tested.relayCancelMap[peer]; !ok {
t.Errorf("expecting peer to be present in relay cancel map, got not present")
}
tested.CancelRefresh(peer)
if _, ok := tested.turnCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in turn cancel map, got present")
}
if _, ok := tested.relayCancelMap[peer]; ok {
t.Errorf("expecting peer to be not present in relay cancel map, got present")
}
}
func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
t.Helper()
mac := hmac.New(algo, key)
_, err := mac.Write([]byte(username))
if err != nil {
t.Fatal(err)
}
expectedMAC := mac.Sum(nil)
decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
if err != nil {
t.Fatal(err)
}
equal := hmac.Equal(decodedMAC, expectedMAC)
if !equal {
t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
}
}

View File

@@ -1,587 +0,0 @@
package grpc
import (
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
shouldSend := debouncer.ProcessUpdate(update)
if !shouldSend {
t.Error("First update should be sent immediately")
}
if debouncer.TimerChannel() == nil {
t.Error("Timer should be started after first update")
}
}
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update should be sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Rapid subsequent updates should be coalesced
if debouncer.ProcessUpdate(update2) {
t.Error("Second rapid update should not be sent immediately")
}
if debouncer.ProcessUpdate(update3) {
t.Error("Third rapid update should not be sent immediately")
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Send second update within debounce period
debouncer.ProcessUpdate(update2)
// Wait for timer
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update2 {
t.Error("Should get the last update")
}
if pendingUpdates[0] == update1 {
t.Error("Should not get the first update")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Wait a bit, but not the full debounce period
time.Sleep(30 * time.Millisecond)
// Send second update - should reset timer
debouncer.ProcessUpdate(update2)
// Wait a bit more
time.Sleep(30 * time.Millisecond)
// Send third update - should reset timer again
debouncer.ProcessUpdate(update3)
// Now wait for the timer (should fire after last update's reset)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
// Timer should be restarted since there was a pending update
if debouncer.TimerChannel() == nil {
t.Error("Timer should be restarted after sending pending update")
}
case <-time.After(150 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
debouncer.ProcessUpdate(update1)
// Second update coalesced
debouncer.ProcessUpdate(update2)
// Wait for timer to expire
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have pending update")
}
// After sending pending update, timer is restarted, so next update is NOT immediate
if debouncer.ProcessUpdate(update3) {
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
}
// Wait for the restarted timer and verify update3 is pending
select {
case <-debouncer.TimerChannel():
finalUpdates := debouncer.GetPendingUpdates()
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
t.Error("Should get update3 as pending")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired for restarted timer")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send update to start timer
debouncer.ProcessUpdate(update)
// Stop should clean up
debouncer.Stop()
// Multiple stops should be safe
debouncer.Stop()
}
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate high-frequency updates
var lastUpdate *network_map.UpdateMessage
sentImmediately := 0
for i := 0; i < 100; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
lastUpdate = update
if debouncer.ProcessUpdate(update) {
sentImmediately++
}
time.Sleep(1 * time.Millisecond) // Very rapid updates
}
// Only first update should be sent immediately
if sentImmediately != 1 {
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != lastUpdate {
t.Error("Should get the very last update")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
// Wait for timer to expire with no additional updates (true quiet period)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates")
}
// After true quiet period, timer should be cleared
if debouncer.TimerChannel() != nil {
t.Error("Timer should be cleared after quiet period")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
updates := make([]*network_map.UpdateMessage, 5)
for i := range updates {
updates[i] = &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
}
// First update sent immediately
debouncer.ProcessUpdate(updates[0])
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
debouncer.ProcessUpdate(updates[1])
debouncer.ProcessUpdate(updates[2])
debouncer.ProcessUpdate(updates[3])
debouncer.ProcessUpdate(updates[4])
// Wait for debounce
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
}
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Wait for timer without sending any more updates (true quiet period)
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates during quiet period")
}
// After true quiet period, next update should be sent immediately
if !debouncer.ProcessUpdate(update2) {
t.Error("Update after true quiet period should be sent immediately")
}
}
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate continuous high-frequency updates
for i := 0; i < 10; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
if i == 0 {
// First one sent immediately
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
} else {
// All others should be coalesced (not sent immediately)
if debouncer.ProcessUpdate(update) {
t.Errorf("Update %d should not be sent immediately", i)
}
}
// Wait a bit but send next update before debounce expires
time.Sleep(20 * time.Millisecond)
}
// Now wait for final debounce
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have the last update pending")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
tokenUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate)
// Send multiple control config updates - they should all be queued
debouncer.ProcessUpdate(tokenUpdate1)
debouncer.ProcessUpdate(tokenUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get both control config updates
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
}
// Control configs should come first
if pendingUpdates[0] != tokenUpdate1 {
t.Error("First pending update should be tokenUpdate1")
}
if pendingUpdates[1] != tokenUpdate2 {
t.Error("Second pending update should be tokenUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
netmapUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate1)
// Send token update and network map update
debouncer.ProcessUpdate(tokenUpdate)
debouncer.ProcessUpdate(netmapUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get 2 updates in order: token, then network map
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
}
// Token update should come first (preserves order)
if pendingUpdates[0] != tokenUpdate {
t.Error("First pending update should be tokenUpdate")
}
// Network map update should come second
if pendingUpdates[1] != netmapUpdate2 {
t.Error("Second pending update should be netmapUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate: 50 network maps -> 1 control config -> 50 network maps
// Expected result: 3 messages (netmap, controlConfig, netmap)
// Send first network map immediately
firstNetmap := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
MessageType: network_map.MessageTypeNetworkMap,
}
if !debouncer.ProcessUpdate(firstNetmap) {
t.Error("First update should be sent immediately")
}
// Send 49 more network maps (will be coalesced to last one)
var lastNetmapBatch1 *network_map.UpdateMessage
for i := 1; i < 50; i++ {
lastNetmapBatch1 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch1)
}
// Send 1 control config
controlConfig := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
debouncer.ProcessUpdate(controlConfig)
// Send 50 more network maps (will be coalesced to last one)
var lastNetmapBatch2 *network_map.UpdateMessage
for i := 50; i < 100; i++ {
lastNetmapBatch2 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch2)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get exactly 3 updates: netmap, controlConfig, netmap
if len(pendingUpdates) != 3 {
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
}
// First should be the last netmap from batch 1
if pendingUpdates[0] != lastNetmapBatch1 {
t.Error("First pending update should be last netmap from batch 1")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
// Second should be the control config
if pendingUpdates[1] != controlConfig {
t.Error("Second pending update should be control config")
}
// Third should be the last netmap from batch 2
if pendingUpdates[2] != lastNetmapBatch2 {
t.Error("Third pending update should be last netmap from batch 2")
}
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}

View File

@@ -1,304 +0,0 @@
//go:build integration
package grpc
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/shared/management/proto"
)
type validateSessionTestSetup struct {
proxyService *ProxyServiceServer
store store.Store
cleanup func()
}
func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
t.Helper()
ctx := context.Background()
testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir())
require.NoError(t, err)
proxyManager := &testValidateSessionProxyManager{store: testStore}
usersManager := &testValidateSessionUsersManager{store: testStore}
proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager)
proxyService.SetProxyManager(proxyManager)
createTestProxies(t, ctx, testStore)
return &validateSessionTestSetup{
proxyService: proxyService,
store: testStore,
cleanup: storeCleanup,
}
}
func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) {
t.Helper()
pubKey, privKey := generateSessionKeyPair(t)
testProxy := &reverseproxy.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
},
},
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Enabled: true,
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
},
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
}
func generateSessionKeyPair(t *testing.T) (string, string) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pub), base64.StdEncoding.EncodeToString(priv)
}
func createSessionToken(t *testing.T, privKeyB64, userID, domain string) string {
t.Helper()
token, err := sessionkey.SignToken(privKeyB64, userID, domain, auth.MethodOIDC, time.Hour)
require.NoError(t, err)
return token
}
func TestValidateSession_UserAllowed(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.True(t, resp.Valid, "User should be allowed access")
assert.Equal(t, "allowedUserId", resp.UserId)
assert.Empty(t, resp.DeniedReason)
}
func TestValidateSession_UserNotInAllowedGroup(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "restrictedProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonGroupUserId", "restricted-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "restricted-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User not in group should be denied")
assert.Equal(t, "not_in_group", resp.DeniedReason)
assert.Equal(t, "nonGroupUserId", resp.UserId)
}
func TestValidateSession_UserInDifferentAccount(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "otherAccountUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "User in different account should be denied")
assert.Equal(t, "account_mismatch", resp.DeniedReason)
}
func TestValidateSession_UserNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "nonExistentUserId", "test-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Non-existent user should be denied")
assert.Equal(t, "user_not_found", resp.DeniedReason)
}
func TestValidateSession_ProxyNotFound(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
proxy, err := setup.store.GetServiceByID(context.Background(), store.LockingStrengthNone, "testAccountId", "testProxyId")
require.NoError(t, err)
token := createSessionToken(t, proxy.SessionPrivateKey, "allowedUserId", "unknown-proxy.example.com")
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "unknown-proxy.example.com",
SessionToken: token,
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
SessionToken: "invalid-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid, "Invalid token should be denied")
assert.Equal(t, "invalid_token", resp.DeniedReason)
}
func TestValidateSession_MissingDomain(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
SessionToken: "some-token",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
func TestValidateSession_MissingToken(t *testing.T) {
setup := setupValidateSessionTest(t)
defer setup.cleanup()
resp, err := setup.proxyService.ValidateSession(context.Background(), &proto.ValidateSessionRequest{
Domain: "test-proxy.example.com",
})
require.NoError(t, err)
assert.False(t, resp.Valid)
assert.Contains(t, resp.DeniedReason, "missing")
}
type testValidateSessionProxyManager struct {
store store.Store
}
func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) {
return "", nil
}
type testValidateSessionUsersManager struct {
store store.Store
}
func (m *testValidateSessionUsersManager) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
}

View File

@@ -15,7 +15,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
settingsManager settings.Manager settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager serviceManager service.Manager
// config contains the management server configuration // config contains the management server configuration
config *nbconfig.Config config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil) var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.reverseProxyManager = serviceManager am.serviceManager = serviceManager
} }
func isUniqueConstraintError(err error) bool { func isUniqueConstraintError(err error) bool {
@@ -394,7 +394,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
} }
if reloadReverseProxy { if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"net/netip" "net/netip"
"time" "time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -140,5 +140,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager) SetServiceManager(serviceManager service.Manager)
} }

View File

@@ -27,8 +27,8 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers" "github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account" nbAccount "github.com/netbirdio/netbird/management/server/account"

View File

@@ -703,7 +703,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store) permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager) resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp" idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
) )
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints // Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil { if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router) idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router) instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router)
if reverseProxyManager != nil && reverseProxyDomainManager != nil { if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
} }
// Register OAuth callback handler for proxy authentication // Register OAuth callback handler for proxy authentication

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer() oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore) usersManager := users.NewManager(testStore)

View File

@@ -12,7 +12,8 @@ import (
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +92,16 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
} }
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) if err != nil {
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager) t.Fatalf("Failed to create proxy token store: %v", err)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager) }
proxyServiceServer.SetProxyManager(reverseProxyManager) proxyMgr := proxymanager.NewManager(store)
am.SetServiceManager(reverseProxyManager) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked // @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +119,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil) apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil { if err != nil {
t.Fatalf("Failed to create API handler: %v", err) t.Fatalf("Failed to create API handler: %v", err)
} }

View File

@@ -12,7 +12,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
@@ -148,7 +148,7 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
} }
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
// Mock implementation - no-op // Mock implementation - no-op
} }

View File

@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
@@ -33,23 +33,23 @@ type Manager interface {
} }
type managerImpl struct { type managerImpl struct {
store store.Store store store.Store
permissionsManager permissions.Manager permissionsManager permissions.Manager
groupsManager groups.Manager groupsManager groups.Manager
accountManager account.Manager accountManager account.Manager
reverseProxyManager reverseproxy.Manager serviceManager service.Manager
} }
type mockManager struct { type mockManager struct {
} }
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
groupsManager: groupsManager, groupsManager: groupsManager,
accountManager: accountManager, accountManager: accountManager,
reverseProxyManager: reverseproxyManager, serviceManager: reverseproxyManager,
} }
} }
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them // TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() { go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID) err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err) log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
} }
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID) serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err) return fmt.Errorf("failed to check if resource is used by service: %w", err)
} }

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types" "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err) require.NoError(t, err)
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err) require.Error(t, err)
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err) require.NoError(t, err)
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err) require.Error(t, err)
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err) require.NoError(t, err)
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err) require.Error(t, err)
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err) require.NoError(t, err)
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource) createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err) require.NoError(t, err)
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource) updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err) require.Error(t, err)
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err) require.NoError(t, err)
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{} am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl) serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err) require.Error(t, err)

View File

@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings var settings *types.Settings
var eventsToStore []func() var eventsToStore []func()
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID) serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil { if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err) return fmt.Errorf("failed to check if resource is used by service: %w", err)
} }

View File

@@ -28,9 +28,10 @@ import (
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{}, &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &accesslogs.AccessLogEntry{}, &proxy.Proxy{},
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err) return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -2063,7 +2064,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil return checks, nil
} }
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth, const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key pass_host_header, rewrite_redirects, session_private_key, session_public_key
@@ -2078,8 +2079,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err return nil, err
} }
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) { services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
var s reverseproxy.Service var s rpservice.Service
var auth []byte var auth []byte
var createdAt, certIssuedAt sql.NullTime var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
@@ -2109,7 +2110,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
} }
} }
s.Meta = reverseproxy.ServiceMeta{} s.Meta = rpservice.ServiceMeta{}
if createdAt.Valid { if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time s.Meta.CreatedAt = createdAt.Time
} }
@@ -2129,7 +2130,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
s.SessionPublicKey = sessionPublicKey.String s.SessionPublicKey = sessionPublicKey.String
} }
s.Targets = []*reverseproxy.Target{} s.Targets = []*rpservice.Target{}
return &s, nil return &s, nil
}) })
if err != nil { if err != nil {
@@ -2141,7 +2142,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
} }
serviceIDs := make([]string, len(services)) serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service) serviceMap := make(map[string]*rpservice.Service)
for i, s := range services { for i, s := range services {
serviceIDs[i] = s.ID serviceIDs[i] = s.ID
serviceMap[s.ID] = s serviceMap[s.ID] = s
@@ -2152,8 +2153,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err return nil, err
} }
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) { targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
var t reverseproxy.Target var t rpservice.Target
var path sql.NullString var path sql.NullString
err := row.Scan( err := row.Scan(
&t.ID, &t.ID,
@@ -4825,7 +4826,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil return peerID, nil
} }
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy() serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err) return fmt.Errorf("encrypt service data: %w", err)
@@ -4839,16 +4840,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
return nil return nil
} }
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy() serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err) return fmt.Errorf("encrypt service data: %w", err)
} }
// Create target type instance outside transaction to avoid variable shadowing
targetType := &rpservice.Target{}
// Use a transaction to ensure atomic updates of the service and its targets // Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets // Delete existing targets
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil { if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
return err return err
} }
@@ -4869,7 +4873,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
} }
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error { func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID) result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store") return status.Errorf(status.Internal, "failed to delete service from store")
@@ -4882,13 +4886,13 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil return nil
} }
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets") tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var service *reverseproxy.Service var service *rpservice.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID) result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4906,8 +4910,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil return service, nil
} }
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
var service *reverseproxy.Service var service *rpservice.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4925,13 +4929,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
return service, nil return service, nil
} }
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets") tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var serviceList []*reverseproxy.Service var serviceList []*rpservice.Service
result := tx.Find(&serviceList) result := tx.Find(&serviceList)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -4947,13 +4951,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
return serviceList, nil return serviceList, nil
} }
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets") tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var serviceList []*reverseproxy.Service var serviceList []*rpservice.Service
result := tx.Find(&serviceList, accountIDCondition, accountID) result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -5181,13 +5185,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
return query return query
} }
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) { func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
} }
var target *reverseproxy.Target var target *rpservice.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID) result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -5200,3 +5204,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
return target, nil return target, nil
} }
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
}
return nil
}
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
}
return addresses, nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
}
if result.RowsAffected > 0 {
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
}
return nil
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&types.Policy{}, &types.PolicyRule{}, &route.Route{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &types.AccountOnboarding{}, &service.Service{}, &service.Target{},
} }
for i := len(models) - 1; i >= 0; i-- { for i := len(models) - 1; i >= 0; i-- {

View File

@@ -25,9 +25,10 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -252,13 +253,13 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error CreateService(ctx context.Context, service *rpservice.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
@@ -270,7 +271,12 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
} }
const ( const (

View File

@@ -12,9 +12,10 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
dns "github.com/netbirdio/netbird/dns" dns "github.com/netbirdio/netbird/dns"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
zones "github.com/netbirdio/netbird/management/internals/modules/zones" zones "github.com/netbirdio/netbird/management/internals/modules/zones"
records "github.com/netbirdio/netbird/management/internals/modules/zones/records" records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
types "github.com/netbirdio/netbird/management/server/networks/resources/types" types "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
} }
// CleanupStaleProxies mocks base method.
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// Close mocks base method. // Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error { func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -293,7 +308,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
} }
// CreateService mocks base method. // CreateService mocks base method.
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, service) ret := m.ctrl.Call(m, "CreateService", ctx, service)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@@ -1095,10 +1110,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
} }
// GetAccountServices mocks base method. // GetAccountServices mocks base method.
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID) ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service) ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1199,6 +1214,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
} }
// GetActiveProxyClusterAddresses mocks base method.
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetAllAccounts mocks base method. // GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -1813,10 +1843,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
} }
// GetServiceByDomain mocks base method. // GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret0, _ := ret[0].(*reverseproxy.Service) ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1828,10 +1858,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
} }
// GetServiceByID mocks base method. // GetServiceByID mocks base method.
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID) ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].(*reverseproxy.Service) ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1843,10 +1873,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
} }
// GetServiceTargetByTargetID mocks base method. // GetServiceTargetByTargetID mocks base method.
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) { func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID) ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
ret0, _ := ret[0].(*reverseproxy.Target) ret0, _ := ret[0].(*service.Target)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -1858,10 +1888,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
} }
// GetServices mocks base method. // GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength) ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
ret0, _ := ret[0].([]*reverseproxy.Service) ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@@ -2536,6 +2566,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
} }
// SaveProxy mocks base method.
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
ret0, _ := ret[0].(error)
return ret0
}
// SaveProxy indicates an expected call of SaveProxy.
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// SaveProxyAccessToken mocks base method. // SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -2731,8 +2775,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
} }
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
}
// UpdateService mocks base method. // UpdateService mocks base method.
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, service) ret := m.ctrl.Call(m, "UpdateService", ctx, service)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -18,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth" "github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"` Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings // Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy()) networkResources = append(networkResources, resource.Copy())
} }
services := []*reverseproxy.Service{} services := []*service.Service{}
for _, service := range a.Services { for _, service := range a.Services {
services = append(services, service.Copy()) services = append(services, service.Copy())
} }
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
} }
} }
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
for _, target := range service.Targets { for _, target := range service.Targets {
if !target.Enabled { if !target.Enabled {
continue continue
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
} }
} }
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) { func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
port, ok := a.resolveTargetPort(ctx, target) port, ok := a.resolveTargetPort(ctx, target)
if !ok { if !ok {
return return
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
} }
} }
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) { func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
if target.Port != 0 { if target.Port != 0 {
return target.Port, true return target.Port, true
} }
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
} }
} }
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
return &Policy{ return &Policy{
ID: policyID, ID: policyID,

View File

@@ -1,94 +0,0 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/shared/management/proto"
)
type mockMappingStream struct {
grpc.ClientStream
messages []*proto.GetMappingUpdateResponse
idx int
}
func (m *mockMappingStream) Recv() (*proto.GetMappingUpdateResponse, error) {
if m.idx >= len(m.messages) {
return nil, io.EOF
}
msg := m.messages[m.idx]
m.idx++
return msg, nil
}
func (m *mockMappingStream) Header() (metadata.MD, error) {
return nil, nil //nolint:nilnil
}
func (m *mockMappingStream) Trailer() metadata.MD { return nil }
func (m *mockMappingStream) CloseSend() error { return nil }
func (m *mockMappingStream) Context() context.Context { return context.Background() }
func (m *mockMappingStream) SendMsg(any) error { return nil }
func (m *mockMappingStream) RecvMsg(any) error { return nil }
func TestHandleMappingStream_SyncCompleteFlag(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{InitialSyncComplete: true},
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "initial sync should be marked done when flag is set")
}
func TestHandleMappingStream_NoSyncFlagDoesNotMarkDone(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // no sync flag
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "initial sync should not be marked done without flag")
}
func TestHandleMappingStream_NilHealthChecker(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{InitialSyncComplete: true},
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync done flag should be set even without health checker")
}

View File

@@ -1,560 +0,0 @@
package proxy
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/proxy"
proxytypes "github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// integrationTestSetup contains all real components for testing.
type integrationTestSetup struct {
store store.Store
proxyService *nbgrpc.ProxyServiceServer
grpcServer *grpc.Server
grpcAddr string
cleanup func()
services []*reverseproxy.Service
}
func setupIntegrationTest(t *testing.T) *integrationTestSetup {
t.Helper()
ctx := context.Background()
// Create real SQLite store
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
// Create test account
testAccount := &types.Account{
Id: "test-account-1",
Domain: "test.com",
DomainCategory: "private",
IsDomainPrimaryAccount: true,
CreatedAt: time.Now(),
}
require.NoError(t, testStore.SaveAccount(ctx, testAccount))
// Generate session keys for reverse proxies
pub, priv, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
// Create test services in the store
services := []*reverseproxy.Service{
{
ID: "rp-1",
AccountID: "test-account-1",
Name: "Test App 1",
Domain: "app1.test.proxy.io",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "10.0.0.1",
Port: 8080,
Protocol: "http",
TargetId: "peer1",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
ProxyCluster: "test.proxy.io",
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
},
{
ID: "rp-2",
AccountID: "test-account-1",
Name: "Test App 2",
Domain: "app2.test.proxy.io",
Targets: []*reverseproxy.Target{{
Path: strPtr("/"),
Host: "10.0.0.2",
Port: 8080,
Protocol: "http",
TargetId: "peer2",
TargetType: "peer",
Enabled: true,
}},
Enabled: true,
ProxyCluster: "test.proxy.io",
SessionPrivateKey: privKey,
SessionPublicKey: pubKey,
},
}
for _, svc := range services {
require.NoError(t, testStore.CreateService(ctx, svc))
}
// Create real token store
tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute)
// Create real users manager
usersManager := users.NewManager(testStore)
// Create real proxy service server with minimal config
oidcConfig := nbgrpc.ProxyOIDCConfig{
Issuer: "https://fake-issuer.example.com",
ClientID: "test-client",
HMACKey: []byte("test-hmac-key"),
}
proxyService := nbgrpc.NewProxyServiceServer(
&testAccessLogManager{},
tokenStore,
oidcConfig,
nil,
usersManager,
)
// Use store-backed service manager
svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore}
proxyService.SetProxyManager(svcMgr)
// Start real gRPC server
lis, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
grpcServer := grpc.NewServer()
proto.RegisterProxyServiceServer(grpcServer, proxyService)
go func() {
if err := grpcServer.Serve(lis); err != nil {
t.Logf("gRPC server error: %v", err)
}
}()
return &integrationTestSetup{
store: testStore,
proxyService: proxyService,
grpcServer: grpcServer,
grpcAddr: lis.Addr().String(),
services: services,
cleanup: func() {
grpcServer.GracefulStop()
cleanup()
},
}
}
// testAccessLogManager provides access log storage for testing.
type testAccessLogManager struct{}
func (m *testAccessLogManager) CleanupOldAccessLogs(ctx context.Context, retentionDays int) (int64, error) {
return 0, nil
}
func (m *testAccessLogManager) StartPeriodicCleanup(ctx context.Context, retentionDays, cleanupIntervalHours int) {
// noop
}
func (m *testAccessLogManager) StopPeriodicCleanup() {
// noop
}
func (m *testAccessLogManager) SaveAccessLog(_ context.Context, _ *accesslogs.AccessLogEntry) error {
return nil
}
func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, _ *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
return nil, 0, nil
}
// storeBackedServiceManager reads directly from the real store.
type storeBackedServiceManager struct {
store store.Store
tokenStore *nbgrpc.OneTimeTokenStore
}
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, errors.New("not implemented")
}
func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return nil, errors.New("not implemented")
}
func (m *storeBackedServiceManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
return nil
}
func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
return nil
}
func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
return nil
}
func (m *storeBackedServiceManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
return nil
}
func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
return nil
}
func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1")
}
func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
}
func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, accountID string, targetID string) (string, error) {
return "", nil
}
func strPtr(s string) *string {
return &s
}
func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "test-proxy-1",
Version: "test-v1",
Address: "test.proxy.io",
})
require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping)
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
}
// Should receive 2 mappings total
assert.Len(t, mappingsByID, 2, "Should receive 2 reverse proxy mappings")
rp1 := mappingsByID["rp-1"]
require.NotNil(t, rp1)
assert.Equal(t, "app1.test.proxy.io", rp1.GetDomain())
assert.Equal(t, "test-account-1", rp1.GetAccountId())
assert.Equal(t, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED, rp1.GetType())
assert.NotEmpty(t, rp1.GetAuthToken(), "Should have auth token for peer creation")
rp2 := mappingsByID["rp-2"]
require.NotNil(t, rp2)
assert.Equal(t, "app2.test.proxy.io", rp2.GetDomain())
}
func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
clusterAddress := "test.proxy.io"
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: "test-proxy-cluster",
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0)
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
}
// Should receive the 2 mappings matching the cluster
assert.Len(t, mappings, 2, "Should receive mappings for the cluster")
for _, mapping := range mappings {
t.Logf("Received mapping: id=%s domain=%s", mapping.GetId(), mapping.GetDomain())
}
}
func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect"
// Helper to receive all mappings from a stream
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping
for i := 0; i < count; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
}
return mappings
}
// First connection
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
firstMappings := receiveMappings(stream1, 2)
cancel1()
time.Sleep(100 * time.Millisecond)
// Second connection (simulating reconnect)
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
secondMappings := receiveMappings(stream2, 2)
// Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings),
"Should receive same number of mappings on reconnect")
firstIDs := make(map[string]bool)
for _, m := range firstMappings {
firstIDs[m.GetId()] = true
}
for _, m := range secondMappings {
assert.True(t, firstIDs[m.GetId()],
"Mapping %s should be present in both connections", m.GetId())
}
}
func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
// Use real auth middleware and proxy to verify idempotency
logger := log.New()
logger.SetLevel(log.WarnLevel)
authMw := auth.NewMiddleware(logger, nil)
proxyHandler := proxy.NewReverseProxy(nil, "auto", nil, logger)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-idempotent"
var addMappingCalls atomic.Int32
applyMappings := func(mappings []*proto.ProxyMapping) {
for _, mapping := range mappings {
if mapping.GetType() == proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED {
addMappingCalls.Add(1)
// Apply to real auth middleware (idempotent)
err := authMw.AddDomain(
mapping.GetDomain(),
nil,
"",
0,
mapping.GetAccountId(),
mapping.GetId(),
)
require.NoError(t, err)
// Apply to real proxy (idempotent)
proxyHandler.AddMapping(proxy.Mapping{
Host: mapping.GetDomain(),
ID: mapping.GetId(),
AccountID: proxytypes.AccountID(mapping.GetAccountId()),
})
}
}
}
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
applyMappings(msg.GetMapping())
}
}
// First connection
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
stream1, err := client.GetMappingUpdate(ctx1, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
receiveAndApply(stream1)
cancel1()
firstCallCount := addMappingCalls.Load()
t.Logf("First connection: applied %d mappings", firstCallCount)
time.Sleep(100 * time.Millisecond)
// Second connection
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
stream2, err := client.GetMappingUpdate(ctx2, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
receiveAndApply(stream2)
cancel2()
time.Sleep(100 * time.Millisecond)
// Third connection
ctx3, cancel3 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel3()
stream3, err := client.GetMappingUpdate(ctx3, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
receiveAndApply(stream3)
totalCalls := addMappingCalls.Load()
t.Logf("After three connections: total applied %d mappings", totalCalls)
// Should have called addMapping 6 times (2 mappings x 3 connections)
// But internal state is NOT duplicated because auth and proxy use maps keyed by domain/host
assert.Equal(t, int32(6), totalCalls, "Should have 6 total calls (2 mappings x 3 connections)")
}
func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T) {
setup := setupIntegrationTest(t)
defer setup.cleanup()
clusterAddress := "test.proxy.io"
var wg sync.WaitGroup
var mu sync.Mutex
receivedByProxy := make(map[string]int)
for i := 1; i <= 3; i++ {
wg.Add(1)
go func(proxyNum int) {
defer wg.Done()
conn, err := grpc.NewClient(setup.grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer conn.Close()
client := proto.NewProxyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
proxyID := "test-proxy-" + string(rune('A'+proxyNum-1))
stream, err := client.GetMappingUpdate(ctx, &proto.GetMappingUpdateRequest{
ProxyId: proxyID,
Version: "test-v1",
Address: clusterAddress,
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0
for i := 0; i < 2; i++ {
msg, err := stream.Recv()
require.NoError(t, err)
count += len(msg.GetMapping())
}
mu.Lock()
receivedByProxy[proxyID] = count
mu.Unlock()
}(i)
}
wg.Wait()
for proxyID, count := range receivedByProxy {
assert.Equal(t, 2, count, "Proxy %s should receive 2 mappings", proxyID)
}
}

View File

@@ -1,106 +0,0 @@
package proxy
import (
"net"
"net/netip"
"testing"
"time"
proxyproto "github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWrapProxyProtocol_OverridesRemoteAddr(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")},
ProxyProtocol: true,
}
raw, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer raw.Close()
ln := srv.wrapProxyProtocol(raw)
realClientIP := "203.0.113.50"
realClientPort := uint16(54321)
accepted := make(chan net.Conn, 1)
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
accepted <- conn
}()
// Connect and send a PROXY v2 header.
conn, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
defer conn.Close()
header := &proxyproto.Header{
Version: 2,
Command: proxyproto.PROXY,
TransportProtocol: proxyproto.TCPv4,
SourceAddr: &net.TCPAddr{IP: net.ParseIP(realClientIP), Port: int(realClientPort)},
DestinationAddr: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 443},
}
_, err = header.WriteTo(conn)
require.NoError(t, err)
select {
case accepted := <-accepted:
defer accepted.Close()
host, _, err := net.SplitHostPort(accepted.RemoteAddr().String())
require.NoError(t, err)
assert.Equal(t, realClientIP, host, "RemoteAddr should reflect the PROXY header source IP")
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for connection")
}
}
func TestProxyProtocolPolicy_TrustedRequires(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.TCPAddr{IP: net.ParseIP("10.0.0.1"), Port: 1234},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.REQUIRE, policy, "trusted source should require PROXY header")
}
func TestProxyProtocolPolicy_UntrustedIgnores(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.TCPAddr{IP: net.ParseIP("203.0.113.50"), Port: 1234},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.IGNORE, policy, "untrusted source should have PROXY header ignored")
}
func TestProxyProtocolPolicy_InvalidIPRejects(t *testing.T) {
srv := &Server{
Logger: log.StandardLogger(),
TrustedProxies: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
}
opts := proxyproto.ConnPolicyOptions{
Upstream: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
}
policy, err := srv.proxyProtocolPolicy(opts)
require.NoError(t, err)
assert.Equal(t, proxyproto.REJECT, policy, "unparsable address should be rejected")
}

View File

@@ -1,48 +0,0 @@
package proxy
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestDebugEndpointDisabledByDefault(t *testing.T) {
s := &Server{}
assert.False(t, s.DebugEndpointEnabled, "debug endpoint should be disabled by default")
}
func TestDebugEndpointAddr(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty defaults to localhost",
input: "",
expected: "localhost:8444",
},
{
name: "explicit localhost preserved",
input: "localhost:9999",
expected: "localhost:9999",
},
{
name: "explicit address preserved",
input: "0.0.0.0:8444",
expected: "0.0.0.0:8444",
},
{
name: "127.0.0.1 preserved",
input: "127.0.0.1:8444",
expected: "127.0.0.1:8444",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := debugEndpointAddr(tc.input)
assert.Equal(t, tc.expected, result)
})
}
}

View File

@@ -1,90 +0,0 @@
package proxy
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestParseTrustedProxies(t *testing.T) {
tests := []struct {
name string
raw string
want []netip.Prefix
wantErr bool
}{
{
name: "empty string returns nil",
raw: "",
want: nil,
},
{
name: "single CIDR",
raw: "10.0.0.0/8",
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
{
name: "single bare IPv4",
raw: "1.2.3.4",
want: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/32")},
},
{
name: "single bare IPv6",
raw: "::1",
want: []netip.Prefix{netip.MustParsePrefix("::1/128")},
},
{
name: "comma-separated CIDRs",
raw: "10.0.0.0/8, 192.168.1.0/24",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "mixed CIDRs and bare IPs",
raw: "10.0.0.0/8, 1.2.3.4, fd00::/8",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("1.2.3.4/32"),
netip.MustParsePrefix("fd00::/8"),
},
},
{
name: "whitespace around entries",
raw: " 10.0.0.0/8 , 192.168.0.0/16 ",
want: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "trailing comma produces no extra entry",
raw: "10.0.0.0/8,",
want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
{
name: "invalid entry",
raw: "not-an-ip",
wantErr: true,
},
{
name: "partially invalid",
raw: "10.0.0.0/8, garbage",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ParseTrustedProxies(tt.raw)
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}