move service manager

This commit is contained in:
pascal
2026-02-20 01:21:05 +01:00
parent d4d885d434
commit 3af287ebab
27 changed files with 267 additions and 264 deletions

View File

@@ -0,0 +1,23 @@
package service
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
import (
"context"
)
type Manager interface {
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
ReloadService(ctx context.Context, accountID, serviceID string) error
GetGlobalServices(ctx context.Context) ([]*Service, error)
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
}

View File

@@ -0,0 +1,225 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ./interface.go
// Package service is a generated GoMock package.
package service
import (
context "context"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
)
// MockManager is a mock of Manager interface.
type MockManager struct {
ctrl *gomock.Controller
recorder *MockManagerMockRecorder
}
// MockManagerMockRecorder is the mock recorder for MockManager.
type MockManagerMockRecorder struct {
mock *MockManager
}
// NewMockManager creates a new mock instance.
func NewMockManager(ctrl *gomock.Controller) *MockManager {
mock := &MockManager{ctrl: ctrl}
mock.recorder = &MockManagerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
return m.recorder
}
// CreateService mocks base method.
func (m *MockManager) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateService indicates an expected call of CreateService.
func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteService indicates an expected call of DeleteService.
func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}
// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, accountID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAccountServices indicates an expected call of GetAccountServices.
func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllServices", ctx, accountID, userID)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllServices indicates an expected call of GetAllServices.
func (mr *MockManagerMockRecorder) GetAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServices", reflect.TypeOf((*MockManager)(nil).GetAllServices), ctx, accountID, userID)
}
// GetGlobalServices mocks base method.
func (m *MockManager) GetGlobalServices(ctx context.Context) ([]*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGlobalServices", ctx)
ret0, _ := ret[0].([]*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGlobalServices indicates an expected call of GetGlobalServices.
func (mr *MockManagerMockRecorder) GetGlobalServices(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalServices", reflect.TypeOf((*MockManager)(nil).GetGlobalServices), ctx)
}
// GetService mocks base method.
func (m *MockManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetService", ctx, accountID, userID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetService indicates an expected call of GetService.
func (mr *MockManagerMockRecorder) GetService(ctx, accountID, userID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetService", reflect.TypeOf((*MockManager)(nil).GetService), ctx, accountID, userID, serviceID)
}
// GetServiceByID mocks base method.
func (m *MockManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, accountID, serviceID)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceByID indicates an expected call of GetServiceByID.
func (mr *MockManagerMockRecorder) GetServiceByID(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceByID", reflect.TypeOf((*MockManager)(nil).GetServiceByID), ctx, accountID, serviceID)
}
// GetServiceIDByTargetID mocks base method.
func (m *MockManager) GetServiceIDByTargetID(ctx context.Context, accountID, resourceID string) (string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceIDByTargetID", ctx, accountID, resourceID)
ret0, _ := ret[0].(string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServiceIDByTargetID indicates an expected call of GetServiceIDByTargetID.
func (mr *MockManagerMockRecorder) GetServiceIDByTargetID(ctx, accountID, resourceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceIDByTargetID", reflect.TypeOf((*MockManager)(nil).GetServiceIDByTargetID), ctx, accountID, resourceID)
}
// ReloadAllServicesForAccount mocks base method.
func (m *MockManager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadAllServicesForAccount", ctx, accountID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadAllServicesForAccount indicates an expected call of ReloadAllServicesForAccount.
func (mr *MockManagerMockRecorder) ReloadAllServicesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadAllServicesForAccount", reflect.TypeOf((*MockManager)(nil).ReloadAllServicesForAccount), ctx, accountID)
}
// ReloadService mocks base method.
func (m *MockManager) ReloadService(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ReloadService", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// ReloadService indicates an expected call of ReloadService.
func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetCertificateIssuedAt", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// SetCertificateIssuedAt indicates an expected call of SetCertificateIssuedAt.
func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetCertificateIssuedAt", reflect.TypeOf((*MockManager)(nil).SetCertificateIssuedAt), ctx, accountID, serviceID)
}
// SetStatus mocks base method.
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
ret0, _ := ret[0].(error)
return ret0
}
// SetStatus indicates an expected call of SetStatus.
func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, accountID, userID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateService indicates an expected call of UpdateService.
func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
}

View File

@@ -1,4 +1,4 @@
package service
package manager
import (
"encoding/json"
@@ -6,10 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -17,11 +17,11 @@ import (
)
type handler struct {
manager reverseproxy.Manager
manager rpservice.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
@@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
@@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
return
}
service := new(reverseproxy.Service)
service := new(rpservice.Service)
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)

View File

@@ -1,4 +1,4 @@
package service
package manager
import (
"context"
@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -36,7 +36,7 @@ type Manager struct {
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) rpservice.Manager {
return &Manager{
store: store,
accountManager: accountManager,
@@ -46,7 +46,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*rpservice.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -70,10 +70,10 @@ func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string)
return services, nil
}
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, service *rpservice.Service) error {
for _, target := range service.Targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case rpservice.TargetTypePeer:
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
@@ -81,7 +81,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, ser
continue
}
target.Host = peer.IP.String()
case reverseproxy.TargetTypeHost:
case rpservice.TargetTypeHost:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
@@ -89,7 +89,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, ser
continue
}
target.Host = resource.Prefix.Addr().String()
case reverseproxy.TargetTypeDomain:
case rpservice.TargetTypeDomain:
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
if err != nil {
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
@@ -97,7 +97,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, ser
continue
}
target.Host = resource.Domain
case reverseproxy.TargetTypeSubnet:
case rpservice.TargetTypeSubnet:
// For subnets we do not do any lookups on the resource
default:
return fmt.Errorf("unknown target type: %s", target.TargetType)
@@ -106,7 +106,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, ser
return nil
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*rpservice.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -127,7 +127,7 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
return service, nil
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, service *rpservice.Service) (*rpservice.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -151,14 +151,14 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *rpservice.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
@@ -185,7 +185,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *rpservice.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
@@ -219,7 +219,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
return nil
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *rpservice.Service) (*rpservice.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
@@ -255,7 +255,7 @@ type serviceUpdateInfo struct {
serviceEnabledChanged bool
}
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *rpservice.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -293,7 +293,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *rpservice.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
@@ -310,7 +310,7 @@ func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Stor
return nil
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveExistingAuthSecrets(service, existingService *rpservice.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
@@ -328,40 +328,40 @@ func (m *Manager) SendServiceUpdateToCluster(accountID string, update *proto.Pro
m.proxyGRPCServer.SendServiceUpdateToCluster(update, clusterAddr)
}
func (m *Manager) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
func (m *Manager) preserveServiceMetadata(service, existingService *rpservice.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) sendServiceUpdateNotifications(accountID string, service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
func (m *Manager) sendServiceUpdateNotifications(accountID string, service *rpservice.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", oidcCfg), updateInfo.oldCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && updateInfo.serviceEnabledChanged:
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", oidcCfg), service.ProxyCluster)
default:
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", oidcCfg), service.ProxyCluster)
}
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*rpservice.Target) error {
for _, target := range targets {
switch target.TargetType {
case reverseproxy.TargetTypePeer:
case rpservice.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
case rpservice.TargetTypeHost, rpservice.TargetTypeSubnet, rpservice.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
@@ -382,7 +382,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
return status.NewPermissionDeniedError()
}
var service *reverseproxy.Service
var service *rpservice.Service
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
@@ -402,7 +402,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -429,7 +429,7 @@ func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, service
}
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status rpservice.Status) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
if err != nil {
@@ -457,7 +457,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
@@ -475,13 +475,13 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
if err != nil {
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
}
return nil
}
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*rpservice.Service, error) {
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)
@@ -497,7 +497,7 @@ func (m *Manager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Servic
return services, nil
}
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*rpservice.Service, error) {
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
if err != nil {
return nil, fmt.Errorf("failed to get service: %w", err)
@@ -511,7 +511,7 @@ func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID strin
return service, nil
}
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get services: %w", err)

View File

@@ -1,4 +1,4 @@
package service
package manager
import (
"context"
@@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -24,9 +24,9 @@ func TestInitializeServiceForCreate(t *testing.T) {
clusterDeriver: nil,
}
service := &reverseproxy.Service{
service := &rpservice.Service{
Domain: "example.com",
Auth: reverseproxy.AuthConfig{},
Auth: rpservice.AuthConfig{},
}
err := mgr.initializeServiceForCreate(ctx, accountID, service)
@@ -44,8 +44,8 @@ func TestInitializeServiceForCreate(t *testing.T) {
clusterDeriver: nil,
}
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}}
service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}}
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
@@ -87,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -99,7 +99,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
},
@@ -110,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
errorType: status.AlreadyExists,
@@ -179,7 +179,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
@@ -215,10 +215,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "new.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
// Mock ExecuteInTransaction to execute the function immediately
@@ -248,10 +248,10 @@ func TestPersistNewService(t *testing.T) {
defer ctrl.Finish()
mockStore := store.NewMockStore(ctrl)
service := &reverseproxy.Service{
service := &rpservice.Service{
ID: "service-123",
Domain: "existing.com",
Targets: []*reverseproxy.Target{},
Targets: []*rpservice.Target{},
}
mockStore.EXPECT().
@@ -260,7 +260,7 @@ func TestPersistNewService(t *testing.T) {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
})
@@ -278,18 +278,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
mgr := &Manager{}
t.Run("preserve password when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "hashed-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "",
},
@@ -302,18 +302,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("preserve pin when empty", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "hashed-pin",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PinAuth: &reverseproxy.PINAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PinAuth: &rpservice.PINAuthConfig{
Enabled: true,
Pin: "",
},
@@ -326,18 +326,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
})
t.Run("do not preserve when password is provided", func(t *testing.T) {
existing := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
existing := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "old-password",
},
},
}
updated := &reverseproxy.Service{
Auth: reverseproxy.AuthConfig{
PasswordAuth: &reverseproxy.PasswordAuthConfig{
updated := &rpservice.Service{
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{
Enabled: true,
Password: "new-password",
},
@@ -354,8 +354,8 @@ func TestPreserveExistingAuthSecrets(t *testing.T) {
func TestPreserveServiceMetadata(t *testing.T) {
mgr := &Manager{}
existing := &reverseproxy.Service{
Meta: reverseproxy.ServiceMeta{
existing := &rpservice.Service{
Meta: rpservice.ServiceMeta{
CertificateIssuedAt: time.Now(),
Status: "active",
},
@@ -363,7 +363,7 @@ func TestPreserveServiceMetadata(t *testing.T) {
SessionPublicKey: "public-key",
}
updated := &reverseproxy.Service{
updated := &rpservice.Service{
Domain: "updated.com",
}

View File

@@ -0,0 +1,463 @@
package service
import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/util/crypt"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Operation string
const (
Create Operation = "create"
Update Operation = "update"
Delete Operation = "delete"
)
type Status string
const (
StatusPending Status = "pending"
StatusActive Status = "active"
StatusTunnelNotCreated Status = "tunnel_not_created"
StatusCertificatePending Status = "certificate_pending"
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
)
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
}
type PasswordAuthConfig struct {
Enabled bool `json:"enabled"`
Password string `json:"password"`
}
type PINAuthConfig struct {
Enabled bool `json:"enabled"`
Pin string `json:"pin"`
}
type BearerAuthConfig struct {
Enabled bool `json:"enabled"`
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
}
type AuthConfig struct {
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
}
func (a *AuthConfig) HashSecrets() error {
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
a.PasswordAuth.Password = hashedPassword
}
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
if err != nil {
return fmt.Errorf("hash pin: %w", err)
}
a.PinAuth.Pin = hashedPin
}
return nil
}
func (a *AuthConfig) ClearSecrets() {
if a.PasswordAuth != nil {
a.PasswordAuth.Password = ""
}
if a.PinAuth != nil {
a.PinAuth.Pin = ""
}
}
type OIDCValidationConfig struct {
Issuer string
Audiences []string
KeysLocation string
MaxTokenAgeSeconds int64
}
type ServiceMeta struct {
CreatedAt time.Time
CertificateIssuedAt time.Time
Status string
}
type Service struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string `gorm:"index"`
ProxyCluster string `gorm:"index"`
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
Enabled bool
PassHostHeader bool
RewriteRedirects bool
Auth AuthConfig `gorm:"serializer:json"`
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
SessionPrivateKey string `gorm:"column:session_private_key"`
SessionPublicKey string `gorm:"column:session_public_key"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
// Service record. This overwrites any existing ID and Meta fields and should
// only be called during initial creation, not for updates.
func (s *Service) InitNewRecord() {
s.ID = xid.New().String()
s.Meta = ServiceMeta{
CreatedAt: time.Now(),
Status: string(StatusPending),
}
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
if s.Auth.BearerAuth != nil {
authConfig.BearerAuth = &api.BearerAuthConfig{
Enabled: s.Auth.BearerAuth.Enabled,
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
}
}
// Convert internal targets to API targets
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
for _, target := range s.Targets {
apiTargets = append(apiTargets, api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
})
}
meta := api.ServiceMeta{
CreatedAt: s.Meta.CreatedAt,
Status: api.ServiceMetaStatus(s.Meta.Status),
}
if !s.Meta.CertificateIssuedAt.IsZero() {
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
}
resp := &api.Service{
Id: s.ID,
Name: s.Name,
Domain: s.Domain,
Targets: apiTargets,
Enabled: s.Enabled,
PassHostHeader: &s.PassHostHeader,
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
}
if s.ProxyCluster != "" {
resp.ProxyCluster = &s.ProxyCluster
}
return resp
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pathMappings = append(pathMappings, &proto.PathMapping{
Path: path,
Target: targetURL.String(),
})
}
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
}
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
auth.Password = true
}
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
auth.Pin = true
}
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
auth.Oidc = true
}
return &proto.ProxyMapping{
Type: operationToProtoType(operation),
Id: s.ID,
Domain: s.Domain,
Path: pathMappings,
AuthToken: authToken,
Auth: auth,
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
}
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
case Update:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
case Delete:
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
default:
log.Fatalf("unknown operation type: %v", op)
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
}
}
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for _, apiTarget := range req.Targets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
target.Host = *apiTarget.Host
}
targets = append(targets, target)
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
}
}
func (s *Service) Validate() error {
if s.Name == "" {
return errors.New("service name is required")
}
if len(s.Name) > 255 {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
// host field will be ignored
case TargetTypeSubnet:
if target.Host == "" {
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
}
default:
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
}
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
}
return nil
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
targets[i] = &targetCopy
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
Name: s.Name,
Domain: s.Domain,
ProxyCluster: s.ProxyCluster,
Targets: targets,
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
}
}
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
if enc == nil {
return nil
}
if s.SessionPrivateKey != "" {
var err error
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
if err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,405 @@
package service
import (
"errors"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/hash/argon2id"
"github.com/netbirdio/netbird/shared/management/proto"
)
func validProxy() *Service {
return &Service{
Name: "test",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
},
}
}
func TestValidate_Valid(t *testing.T) {
require.NoError(t, validProxy().Validate())
}
func TestValidate_EmptyName(t *testing.T) {
rp := validProxy()
rp.Name = ""
assert.ErrorContains(t, rp.Validate(), "name is required")
}
func TestValidate_EmptyDomain(t *testing.T) {
rp := validProxy()
rp.Domain = ""
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
}
func TestValidate_EmptyTargetId(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetId = ""
assert.ErrorContains(t, rp.Validate(), "empty target_id")
}
func TestValidate_InvalidTargetType(t *testing.T) {
rp := validProxy()
rp.Targets[0].TargetType = "invalid"
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
}
func TestValidate_ResourceTarget(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "resource-1",
TargetType: TargetTypeHost,
Host: "example.org",
Port: 443,
Protocol: "https",
Enabled: true,
})
require.NoError(t, rp.Validate())
}
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
rp := validProxy()
rp.Targets = append(rp.Targets, &Target{
TargetId: "",
TargetType: TargetTypePeer,
Host: "10.0.0.2",
Port: 80,
Protocol: "http",
Enabled: true,
})
err := rp.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "target 1")
assert.Contains(t, err.Error(), "empty target_id")
}
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
want bool
}{
{"http", 80, true},
{"https", 443, true},
{"http", 443, false},
{"https", 80, false},
{"http", 8080, false},
{"https", 8443, false},
{"http", 0, false},
{"https", 0, false},
}
for _, tt := range tests {
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
})
}
}
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
oidcConfig := OIDCValidationConfig{}
tests := []struct {
name string
protocol string
host string
port int
wantTarget string
}{
{
name: "http with default port 80 omits port",
protocol: "http",
host: "10.0.0.1",
port: 80,
wantTarget: "http://10.0.0.1/",
},
{
name: "https with default port 443 omits port",
protocol: "https",
host: "10.0.0.1",
port: 443,
wantTarget: "https://10.0.0.1/",
},
{
name: "port 0 omits port",
protocol: "http",
host: "10.0.0.1",
port: 0,
wantTarget: "http://10.0.0.1/",
},
{
name: "non-default port is included",
protocol: "http",
host: "10.0.0.1",
port: 8080,
wantTarget: "http://10.0.0.1:8080/",
},
{
name: "https with non-default port is included",
protocol: "https",
host: "10.0.0.1",
port: 8443,
wantTarget: "https://10.0.0.1:8443/",
},
{
name: "http port 443 is included",
protocol: "http",
host: "10.0.0.1",
port: 443,
wantTarget: "http://10.0.0.1:443/",
},
{
name: "https port 80 is included",
protocol: "https",
host: "10.0.0.1",
port: 80,
wantTarget: "https://10.0.0.1:80/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{
TargetId: "peer-1",
TargetType: TargetTypePeer,
Host: tt.host,
Port: tt.port,
Protocol: tt.protocol,
Enabled: true,
},
},
}
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
require.Len(t, pm.Path, 1, "should have one path mapping")
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
})
}
}
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
rp := &Service{
ID: "test-id",
AccountID: "acc-1",
Domain: "example.com",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
},
}
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
require.Len(t, pm.Path, 1)
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
}
func TestToProtoMapping_OperationTypes(t *testing.T) {
rp := validProxy()
tests := []struct {
op Operation
want proto.ProxyMappingUpdateType
}{
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
}
for _, tt := range tests {
t.Run(string(tt.op), func(t *testing.T) {
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
assert.Equal(t, tt.want, pm.Type)
})
}
}
func TestAuthConfig_HashSecrets(t *testing.T) {
tests := []struct {
name string
config *AuthConfig
wantErr bool
validate func(*testing.T, *AuthConfig)
}{
{
name: "hash password successfully",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "testPassword123",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
}
// Verify the hash can be verified
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash PIN successfully",
config: &AuthConfig{
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "123456",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
}
// Verify the hash can be verified
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
t.Errorf("Hash verification failed: %v", err)
}
},
},
{
name: "hash both password and PIN",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "password",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "9999",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
t.Errorf("Password not hashed with argon2id")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN not hashed with argon2id")
}
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
t.Errorf("Password hash verification failed: %v", err)
}
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
t.Errorf("PIN hash verification failed: %v", err)
}
},
},
{
name: "skip disabled password auth",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: false,
Password: "password",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "password" {
t.Errorf("Disabled password auth should not be hashed")
}
},
},
{
name: "skip empty password",
config: &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth.Password != "" {
t.Errorf("Empty password should remain empty")
}
},
},
{
name: "skip nil password auth",
config: &AuthConfig{
PasswordAuth: nil,
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "1234",
},
},
wantErr: false,
validate: func(t *testing.T, config *AuthConfig) {
if config.PasswordAuth != nil {
t.Errorf("PasswordAuth should remain nil")
}
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
t.Errorf("PIN should still be hashed")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.HashSecrets()
if (err != nil) != tt.wantErr {
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.validate != nil {
tt.validate(t, tt.config)
}
})
}
}
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "correctPassword",
},
}
if err := config.HashSecrets(); err != nil {
t.Fatalf("HashSecrets() error = %v", err)
}
// Verify with wrong password should fail
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
}
}
func TestAuthConfig_ClearSecrets(t *testing.T) {
config := &AuthConfig{
PasswordAuth: &PasswordAuthConfig{
Enabled: true,
Password: "hashedPassword",
},
PinAuth: &PINAuthConfig{
Enabled: true,
Pin: "hashedPin",
},
}
config.ClearSecrets()
if config.PasswordAuth.Password != "" {
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
}
if config.PinAuth.Pin != "" {
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
}
}