Merge branch 'refs/heads/main' into refactor/permissions-manager

# Conflicts:
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/server/http/handler.go
This commit is contained in:
pascal
2026-03-06 11:37:53 +01:00
134 changed files with 13391 additions and 3786 deletions

View File

@@ -15,7 +15,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/shared/auth"
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
requestBuffer *AccountRequestBuffer
proxyController port_forwarding.Controller
settingsManager settings.Manager
reverseProxyManager reverseproxy.Manager
proxyController port_forwarding.Controller
settingsManager settings.Manager
serviceManager service.Manager
// config contains the management server configuration
config *nbconfig.Config
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
var _ account.Manager = (*DefaultAccountManager)(nil)
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
am.reverseProxyManager = serviceManager
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
am.serviceManager = serviceManager
}
func isUniqueConstraintError(err error) bool {
@@ -376,6 +376,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
@@ -394,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
}
if reloadReverseProxy {
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
}
}
@@ -492,6 +493,21 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
}
}
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
oldEnabled := oldSettings.PeerExposeEnabled
newEnabled := newSettings.PeerExposeEnabled
if oldEnabled == newEnabled {
return
}
event := activity.AccountPeerExposeEnabled
if !newEnabled {
event = activity.AccountPeerExposeDisabled
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
@@ -714,7 +730,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
@@ -1363,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
// We override incoming domain claims to group users under a single account.
userAuth.Domain = am.singleAccountModeDomain
userAuth.DomainCategory = types.PrivateCategory
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
if err != nil {
return "", "", err
}
}
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
@@ -1398,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
userAuth.DomainCategory = types.PrivateCategory
userAuth.Domain = am.singleAccountModeDomain
accountID, err := am.Store.GetAnyAccountID(ctx)
if err != nil {
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
return err
}
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
return nil
}
if accountID == "" {
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
return nil
}
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
userAuth.Domain = domain
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
return nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager

View File

@@ -1,12 +1,14 @@
package account
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
import (
"context"
"net"
"net/netip"
"time"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/shared/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -61,11 +63,11 @@ type Manager interface {
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -140,5 +142,5 @@ type Manager interface {
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
SetServiceManager(serviceManager reverseproxy.Manager)
SetServiceManager(serviceManager service.Manager)
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,10 +15,12 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
@@ -27,8 +29,10 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -1803,12 +1807,12 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
Services: []*reverseproxy.Service{
Services: []*service.Service{
{
ID: "service1",
Name: "test-service",
AccountID: "account1",
Targets: []*reverseproxy.Target{},
Targets: []*service.Target{},
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
@@ -3113,6 +3117,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
proxyManager := proxy.NewMockManager(ctrl)
proxyManager.EXPECT().
CleanupStale(gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
@@ -3123,8 +3133,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil)
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil))
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil {
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
return manager, updateManager, nil
}
@@ -3953,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
}
}
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("real-domain.com", "private", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "real-domain.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", nil)
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.NotFound, "no accounts"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.NoError(t, err)
assert.Equal(t, "fallback.com", userAuth.Domain)
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("", status.Errorf(status.Internal, "db down"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "db down")
// Defaults should still be set before error path
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
})
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetAnyAccountID(gomock.Any()).
Return("account-1", nil)
mockStore.EXPECT().
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
Return("", "", status.Errorf(status.Internal, "query failed"))
am := &DefaultAccountManager{
Store: mockStore,
singleAccountModeDomain: "fallback.com",
}
userAuth := &auth.UserAuth{}
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
require.Error(t, err)
assert.Contains(t, err.Error(), "query failed")
})
}

View File

@@ -208,6 +208,18 @@ const (
ServiceUpdated Activity = 109
ServiceDeleted Activity = 110
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
PeerServiceExposed Activity = 111
// PeerServiceUnexposed indicates that a peer-exposed service was removed
PeerServiceUnexposed Activity = 112
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
PeerServiceExposeExpired Activity = 113
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
AccountPeerExposeEnabled Activity = 114
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
AccountPeerExposeDisabled Activity = 115
AccountDeleted Activity = 99999
)
@@ -345,6 +357,13 @@ var activityMap = map[Activity]Code{
ServiceCreated: {"Service created", "service.create"},
ServiceUpdated: {"Service updated", "service.update"},
ServiceDeleted: {"Service deleted", "service.delete"},
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
}
// StringCode returns a string code of the activity

View File

@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
switch storeEngine {
case types.SqliteStoreEngine:
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
dbFile := eventSinkDB
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
dbFile = envFile
}
connStr := dbFile
if !filepath.IsAbs(dbFile) {
connStr = filepath.Join(dataDir, dbFile)
}
dialector = sqlite.Open(connStr)
case types.PostgresStoreEngine:
dsn, ok := os.LookupEnv(postgresDsnEnv)
if !ok {

View File

@@ -425,6 +425,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var groupIDsToDelete []string
var deletedGroups []*types.Group
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
@@ -433,7 +438,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
@@ -621,7 +626,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil
}
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
@@ -641,6 +646,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}

View File

@@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -26,6 +27,7 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
@@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
}
}
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
am, _, err := createManager(t)
require.NoError(t, err)
ctrl := gomock.NewController(t)
settingsMock := settings.NewMockManager(ctrl)
settingsMock.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
AnyTimes()
settingsMock.EXPECT().
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(false, nil).
AnyTimes()
am.settingsManager = settingsMock
_, account, err := initTestGroupAccount(am)
require.NoError(t, err)
grp := &types.Group{
ID: "grp-for-flow",
AccountID: account.Id,
Name: "Group for flow",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
require.Error(t, err)
var gErr *GroupLinkError
require.ErrorAs(t, err, &gErr)
assert.Equal(t, "settings", gErr.Resource)
assert.Equal(t, "traffic event logging", gErr.Name)
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
regularGrp := &types.Group{
ID: "grp-regular",
AccountID: account.Id,
Name: "Regular group",
Issued: types.GroupIssuedAPI,
Peers: make([]string, 0),
}
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
require.NoError(t, err)
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
require.Error(t, err)
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
require.NoError(t, err)
assert.NotNil(t, group)
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
assert.Error(t, err)
}
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
accountID := "testingAcc"
domain := "example.com"
@@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving group linked to network router", func(t *testing.T) {
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
@@ -73,7 +73,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
idp.AddEndpoints(accountManager, router, permissionsManager)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}
// Register OAuth callback handler for proxy authentication

View File

@@ -163,6 +163,10 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAut
}
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 {
return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group")
}
returnSettings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
@@ -170,6 +174,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
PeerExposeEnabled: req.Settings.PeerExposeEnabled,
PeerExposeGroups: req.Settings.PeerExposeGroups,
}
if req.Settings.Extra != nil {
@@ -317,6 +324,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
JwtAllowGroups: &jwtAllowGroups,
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
PeerExposeEnabled: settings.PeerExposeEnabled,
PeerExposeGroups: settings.PeerExposeGroups,
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,

View File

@@ -18,8 +18,8 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer()
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
usersManager := users.NewManager(testStore)
@@ -208,9 +209,10 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcConfig,
nil,
usersManager,
nil,
)
proxyService.SetProxyManager(&testServiceManager{store: testStore})
proxyService.SetServiceManager(&testServiceManager{store: testStore})
handler := NewAuthCallbackHandler(proxyService, nil)
@@ -239,12 +241,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
pubKey := base64.StdEncoding.EncodeToString(pub)
privKey := base64.StdEncoding.EncodeToString(priv)
testProxy := &reverseproxy.Service{
testProxy := &service.Service{
ID: "testProxyId",
AccountID: "testAccountId",
Name: "Test Proxy",
Domain: "test-proxy.example.com",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -254,8 +256,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"allowedGroupId"},
},
@@ -265,12 +267,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
}
require.NoError(t, testStore.CreateService(ctx, testProxy))
restrictedProxy := &reverseproxy.Service{
restrictedProxy := &service.Service{
ID: "restrictedProxyId",
AccountID: "testAccountId",
Name: "Restricted Proxy",
Domain: "restricted-proxy.example.com",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -280,8 +282,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: true,
DistributionGroups: []string{"restrictedGroupId"},
},
@@ -291,12 +293,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
}
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
noAuthProxy := &reverseproxy.Service{
noAuthProxy := &service.Service{
ID: "noAuthProxyId",
AccountID: "testAccountId",
Name: "No Auth Proxy",
Domain: "no-auth-proxy.example.com",
Targets: []*reverseproxy.Target{{
Targets: []*service.Target{{
Path: strPtr("/"),
Host: "localhost",
Port: 8080,
@@ -306,8 +308,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
Enabled: true,
}},
Enabled: true,
Auth: reverseproxy.AuthConfig{
BearerAuth: &reverseproxy.BearerAuthConfig{
Auth: service.AuthConfig{
BearerAuth: &service.BearerAuthConfig{
Enabled: false,
},
},
@@ -361,19 +363,19 @@ func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, u
return nil
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
return nil, nil
}
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
return nil, nil
}
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
return nil, nil
}
@@ -385,7 +387,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
return nil
}
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
return nil
}
@@ -397,15 +399,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
return nil
}
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
return m.store.GetServices(ctx, store.LockingStrengthNone)
}
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
}
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}
@@ -413,6 +415,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
return "", nil
}
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()

View File

@@ -9,10 +9,13 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"go.opentelemetry.io/otel/metric/noop"
"github.com/netbirdio/management-integrations/integrations"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
@@ -91,12 +94,24 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
proxyServiceServer.SetProxyManager(reverseProxyManager)
am.SetServiceManager(reverseProxyManager)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err)
}
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil {
t.Fatalf("Failed to create proxy controller: %v", err)
}
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
proxyServiceServer.SetServiceManager(serviceManager)
am.SetServiceManager(serviceManager)
// @note this is required so that PAT's validate from store, but JWT's are mocked
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
@@ -114,7 +129,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
type EmbeddedStorageConfig struct {
// Type is the storage type (currently only "sqlite3" is supported)
// Type is the storage type: "sqlite3" (default) or "postgres"
Type string
// Config contains type-specific configuration
Config EmbeddedStorageTypeConfig
@@ -62,6 +62,8 @@ type EmbeddedStorageConfig struct {
type EmbeddedStorageTypeConfig struct {
// File is the path to the SQLite database file (for sqlite3 type)
File string
// DSN is the connection string for postgres
DSN string
}
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
@@ -74,6 +76,22 @@ type OwnerConfig struct {
Username string
}
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
switch storageType {
case "sqlite3":
return map[string]interface{}{
"file": cfg.File,
}, nil
case "postgres":
return map[string]interface{}{
"dsn": cfg.DSN,
}, nil
default:
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
}
}
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Issuer == "" {
@@ -85,6 +103,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
return nil, fmt.Errorf("storage file is required for sqlite3")
}
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
return nil, fmt.Errorf("storage DSN is required for postgres")
}
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
if err != nil {
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
}
// Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs
@@ -100,10 +126,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
cfg := &dex.YAMLConfig{
Issuer: c.Issuer,
Storage: dex.Storage{
Type: c.Storage.Type,
Config: map[string]interface{}{
"file": c.Storage.Config.File,
},
Type: c.Storage.Type,
Config: storageConfig,
},
Web: dex.Web{
AllowedOrigins: []string{"*"},

View File

@@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/idp/dex"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/types"
@@ -51,6 +52,7 @@ type properties map[string]interface{}
type DataSource interface {
GetAllAccounts(ctx context.Context) []*types.Account
GetStoreEngine() types.Engine
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
}
// ConnManager peer connection manager that holds state for current active connections
@@ -211,6 +213,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
localUsers int
idpUsers int
embeddedIdpTypes map[string]int
services int
servicesEnabled int
servicesTargets int
servicesStatusActive int
servicesStatusPending int
servicesStatusError int
servicesTargetType map[string]int
servicesAuthPassword int
servicesAuthPin int
servicesAuthOIDC int
)
start := time.Now()
metricsProperties := make(properties)
@@ -220,10 +232,13 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
embeddedIdpTypes = make(map[string]int)
servicesTargetType = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion()
customDomains, customDomainsValidated, _ := w.dataSource.GetCustomDomainsCounts(ctx)
for _, account := range w.dataSource.GetAllAccounts(ctx) {
accounts++
@@ -279,9 +294,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
localUsers++
} else {
idpUsers++
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
}
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
}
}
}
@@ -335,6 +350,37 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
}
}
for _, service := range account.Services {
services++
if service.Enabled {
servicesEnabled++
}
servicesTargets += len(service.Targets)
switch rpservice.Status(service.Meta.Status) {
case rpservice.StatusActive:
servicesStatusActive++
case rpservice.StatusPending:
servicesStatusPending++
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
servicesStatusError++
}
for _, target := range service.Targets {
servicesTargetType[target.TargetType]++
}
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled {
servicesAuthPassword++
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled {
servicesAuthPin++
}
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
servicesAuthOIDC++
}
}
}
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
@@ -375,6 +421,22 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["idp_users_count"] = idpUsers
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
metricsProperties["services"] = services
metricsProperties["services_enabled"] = servicesEnabled
metricsProperties["services_targets"] = servicesTargets
metricsProperties["services_status_active"] = servicesStatusActive
metricsProperties["services_status_pending"] = servicesStatusPending
metricsProperties["services_status_error"] = servicesStatusError
metricsProperties["services_auth_password"] = servicesAuthPassword
metricsProperties["services_auth_pin"] = servicesAuthPin
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
metricsProperties["custom_domains"] = customDomains
metricsProperties["custom_domains_validated"] = customDomainsValidated
for targetType, count := range servicesTargetType {
metricsProperties["services_target_type_"+targetType] = count
}
for idpType, count := range embeddedIdpTypes {
metricsProperties["embedded_idp_users_"+idpType] = count
}
@@ -469,6 +531,9 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
// Returns the type prefix, or "oidc" if no known prefix is found.
func extractIdpType(connectorID string) string {
if connectorID == "local" {
return "local"
}
idx := strings.LastIndex(connectorID, "-")
if idx <= 0 {
return "oidc"

View File

@@ -6,6 +6,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/idp/dex"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -28,6 +29,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
localUserID := dex.EncodeDexUserID("10", "local")
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
return []*types.Account{
{
Id: "1",
@@ -115,6 +117,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
},
},
},
Services: []*rpservice.Service{
{
ID: "svc1",
Enabled: true,
Targets: []*rpservice.Target{
{TargetType: "peer"},
{TargetType: "host"},
},
Auth: rpservice.AuthConfig{
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
},
{
ID: "svc2",
Enabled: false,
Targets: []*rpservice.Target{
{TargetType: "domain"},
},
Auth: rpservice.AuthConfig{
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
},
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
},
},
},
{
Id: "2",
@@ -180,6 +207,13 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": {},
},
},
oidcUserID: {
Id: oidcUserID,
IsServiceUser: false,
PATs: map[string]*types.PersonalAccessToken{
"1": {},
},
},
},
Networks: []*networkTypes.Network{
{
@@ -215,6 +249,11 @@ func (mockDatasource) GetStoreEngine() types.Engine {
return types.FileStoreEngine
}
// GetCustomDomainsCounts returns test custom domain counts.
func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
return 3, 2, nil
}
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
func TestGenerateProperties(t *testing.T) {
ds := mockDatasource{}
@@ -247,14 +286,14 @@ func TestGenerateProperties(t *testing.T) {
if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"])
}
if properties["users"] != 2 {
t.Errorf("expected 1 users, got %d", properties["users"])
if properties["users"] != 3 {
t.Errorf("expected 3 users, got %d", properties["users"])
}
if properties["setup_keys_usage"] != 2 {
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
}
if properties["pats"] != 4 {
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"])
if properties["pats"] != 5 {
t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
}
if properties["peers_ssh_enabled"] != 2 {
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
@@ -338,14 +377,63 @@ func TestGenerateProperties(t *testing.T) {
if properties["local_users_count"] != 1 {
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
}
if properties["idp_users_count"] != 1 {
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
if properties["idp_users_count"] != 2 {
t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_local"] != 1 {
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
}
if properties["embedded_idp_users_zitadel"] != 1 {
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
}
if properties["embedded_idp_count"] != 1 {
t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"])
if properties["embedded_idp_users_oidc"] != 1 {
t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
}
if properties["embedded_idp_count"] != 3 {
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
}
if properties["services"] != 2 {
t.Errorf("expected 2 services, got %v", properties["services"])
}
if properties["services_enabled"] != 1 {
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
}
if properties["services_targets"] != 3 {
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
}
if properties["services_status_active"] != 1 {
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
}
if properties["services_status_pending"] != 1 {
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
}
if properties["services_status_error"] != 0 {
t.Errorf("expected 0 services_status_error, got %v", properties["services_status_error"])
}
if properties["services_target_type_peer"] != 1 {
t.Errorf("expected 1 services_target_type_peer, got %v", properties["services_target_type_peer"])
}
if properties["services_target_type_host"] != 1 {
t.Errorf("expected 1 services_target_type_host, got %v", properties["services_target_type_host"])
}
if properties["services_target_type_domain"] != 1 {
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
}
if properties["services_auth_password"] != 1 {
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
}
if properties["services_auth_oidc"] != 1 {
t.Errorf("expected 1 services_auth_oidc, got %v", properties["services_auth_oidc"])
}
if properties["services_auth_pin"] != 0 {
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
}
if properties["custom_domains"] != int64(3) {
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
}
if properties["custom_domains_validated"] != int64(2) {
t.Errorf("expected 2 custom_domains_validated, got %v", properties["custom_domains_validated"])
}
}
@@ -362,7 +450,8 @@ func TestExtractIdpType(t *testing.T) {
{"microsoft-abc123", "microsoft"},
{"authentik-abc123", "authentik"},
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
{"local", "oidc"},
{"local", "local"},
{"d6jvvp69kmnc73c9pl40", "oidc"},
{"", "oidc"},
}

View File

@@ -12,7 +12,7 @@ import (
"google.golang.org/grpc/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
@@ -148,7 +148,7 @@ type MockAccountManager struct {
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
}
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
// Mock implementation - no-op
}
@@ -407,7 +407,7 @@ func (am *MockAccountManager) AddPeer(
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
if am.GetGroupFunc != nil {
if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")

View File

@@ -7,7 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
@@ -33,23 +33,23 @@ type Manager interface {
}
type managerImpl struct {
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
reverseProxyManager reverseproxy.Manager
store store.Store
permissionsManager permissions.Manager
groupsManager groups.Manager
accountManager account.Manager
serviceManager service.Manager
}
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
return &managerImpl{
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
reverseProxyManager: reverseproxyManager,
store: store,
permissionsManager: permissionsManager,
groupsManager: groupsManager,
accountManager: accountManager,
serviceManager: reverseproxyManager,
}
}
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
go func() {
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
if err != nil {
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
}
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return status.NewPermissionDeniedError()
}
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}

View File

@@ -7,7 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.NoError(t, err)
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
require.Error(t, err)
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.NoError(t, err)
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
require.Error(t, err)
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
createdResource, err := manager.CreateResource(ctx, userID, resource)
require.Error(t, err)
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.NoError(t, err)
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
require.Error(t, err)
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.NoError(t, err)
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
am := mock_server.MockAccountManager{}
groupsManager := groups.NewManagerMock()
ctrl := gomock.NewController(t)
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
serviceManager := reverseproxy.NewMockManager(ctrl)
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
require.Error(t, err)

View File

@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var settings *types.Settings
var eventsToStore []func()
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
}

View File

@@ -352,9 +352,10 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest)
p.Name = a.Name
p.Key = a.WgPubKey
p.Meta = PeerSystemMeta{
Hostname: a.Name,
GoOS: "js",
OS: "js",
Hostname: a.Name,
GoOS: "js",
OS: "js",
KernelVersion: "wasm",
}
}

View File

@@ -269,3 +269,8 @@ func (s *FileStore) GetStoreEngine() types.Engine {
func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
// no-op: FileStore stores data in plaintext JSON; encryption is not supported
}
// GetCustomDomainsCounts is a no-op for FileStore as it doesn't support custom domains.
func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
return 0, 0, nil
}

View File

@@ -28,9 +28,10 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{},
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -1007,6 +1008,18 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
return count, nil
}
// GetCustomDomainsCounts returns the total and validated custom domain counts.
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
var total, validated int64
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
return 0, 0, err
}
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
return 0, 0, err
}
return total, validated, nil
}
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
var accounts []types.Account
result := s.db.Find(&accounts)
@@ -2063,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
return checks, nil
}
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
@@ -2078,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err
}
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
var s reverseproxy.Service
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
var s rpservice.Service
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
@@ -2109,12 +2122,13 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
}
}
s.Meta = reverseproxy.ServiceMeta{}
s.Meta = rpservice.Meta{}
if createdAt.Valid {
s.Meta.CreatedAt = createdAt.Time
}
if certIssuedAt.Valid {
s.Meta.CertificateIssuedAt = certIssuedAt.Time
t := certIssuedAt.Time
s.Meta.CertificateIssuedAt = &t
}
if status.Valid {
s.Meta.Status = status.String
@@ -2129,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
s.SessionPublicKey = sessionPublicKey.String
}
s.Targets = []*reverseproxy.Target{}
s.Targets = []*rpservice.Target{}
return &s, nil
})
if err != nil {
@@ -2141,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
}
serviceIDs := make([]string, len(services))
serviceMap := make(map[string]*reverseproxy.Service)
serviceMap := make(map[string]*rpservice.Service)
for i, s := range services {
serviceIDs[i] = s.ID
serviceMap[s.ID] = s
@@ -2152,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
return nil, err
}
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
var t reverseproxy.Target
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
var t rpservice.Target
var path sql.NullString
err := row.Scan(
&t.ID,
@@ -2715,14 +2729,28 @@ func (s *SqlStore) GetStoreEngine() types.Engine {
// NewSqliteStore creates a new SQLite store.
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
if runtime.GOOS == "windows" {
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
storeStr = storeSqliteFileName
storeFile := storeSqliteFileName
if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" {
storeFile = envFile
}
file := filepath.Join(dataDir, storeStr)
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
// Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc")
filePath, query, hasQuery := strings.Cut(storeFile, "?")
connStr := filePath
if !filepath.IsAbs(filePath) {
connStr = filepath.Join(dataDir, filePath)
}
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
if hasQuery {
connStr += "?" + query
} else if runtime.GOOS != "windows" {
// To avoid `The process cannot access the file because it is being used by another process` on Windows
connStr += "?cache=shared"
}
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
if err != nil {
return nil, err
}
@@ -4825,7 +4853,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
return peerID, nil
}
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
@@ -4839,16 +4867,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
return nil
}
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
serviceCopy := service.Copy()
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
return fmt.Errorf("encrypt service data: %w", err)
}
// Create target type instance outside transaction to avoid variable shadowing
targetType := &rpservice.Target{}
// Use a transaction to ensure atomic updates of the service and its targets
err := s.db.Transaction(func(tx *gorm.DB) error {
// Delete existing targets
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
return err
}
@@ -4869,7 +4900,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
}
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete service from store")
@@ -4882,13 +4913,53 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
return nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete target from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
}
return nil
}
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete targets from store")
}
return nil
}
// GetTargetsByServiceID retrieves all targets for a given service
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
var targets []*rpservice.Target
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get targets from store")
}
return targets, nil
}
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var service *reverseproxy.Service
var service *rpservice.Service
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4906,30 +4977,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
return service, nil
}
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get services from store")
}
for _, service := range serviceList {
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
return nil, fmt.Errorf("decrypt service data: %w", err)
}
}
return serviceList, nil
}
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
var service *reverseproxy.Service
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
var service *rpservice.Service
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -4947,13 +4996,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
return service, nil
}
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
var serviceList []*rpservice.Service
result := tx.Find(&serviceList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -4969,13 +5018,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
return serviceList, nil
}
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
tx := s.db.Preload("Targets")
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var serviceList []*reverseproxy.Service
var serviceList []*rpservice.Service
result := tx.Find(&serviceList, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
@@ -4991,6 +5040,99 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
return serviceList, nil
}
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Update("meta_last_renewed_at", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
return status.Errorf(status.Internal, "renew ephemeral service")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
}
// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL.
// Only the fields needed for reaping are selected. The limit parameter caps the batch size to
// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to
// skip malformed legacy data.
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
cutoff := time.Now().Add(-ttl)
var services []*rpservice.Service
result := s.db.
Select("id", "account_id", "source_peer", "domain").
Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff).
Limit(limit).
Find(&services)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error)
return nil, status.Errorf(status.Internal, "get expired ephemeral services")
}
return services, nil
}
// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer.
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to
// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*).
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
if lockStrength == LockingStrengthNone {
var count int64
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
return 0, status.Errorf(status.Internal, "count ephemeral services")
}
return count, nil
}
var ids []string
result := s.db.Model(&rpservice.Service{}).
Clauses(clause.Locking{Strength: string(lockStrength)}).
Select("id").
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
Pluck("id", &ids)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
return 0, status.Errorf(status.Internal, "count ephemeral services")
}
return int64(len(ids)), nil
}
// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain.
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
if lockStrength == LockingStrengthNone {
var count int64
result := s.db.Model(&rpservice.Service{}).
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Count(&count)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
return false, status.Errorf(status.Internal, "check ephemeral service existence")
}
return count > 0, nil
}
var id string
result := s.db.Model(&rpservice.Service{}).
Clauses(clause.Locking{Strength: string(lockStrength)}).
Select("id").
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
Limit(1).
Pluck("id", &id)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
return false, status.Errorf(status.Internal, "check ephemeral service existence")
}
return id != "", nil
}
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
tx := s.db
@@ -5203,13 +5345,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
return query
}
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var target *reverseproxy.Target
var target *rpservice.Target
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -5222,3 +5364,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
return target, nil
}
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
}
return nil
}
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", time.Now())
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
}
return nil
}
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
Distinct("cluster_address").
Pluck("cluster_address", &addresses)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
}
return addresses, nil
}
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
}
if result.RowsAffected > 0 {
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
}
return nil
}

View File

@@ -20,7 +20,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
}
for i := len(models) - 1; i >= 0; i-- {

View File

@@ -25,9 +25,10 @@ import (
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry"
@@ -252,14 +253,18 @@ type Store interface {
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
CreateService(ctx context.Context, service *reverseproxy.Service) error
UpdateService(ctx context.Context, service *reverseproxy.Service) error
CreateService(ctx context.Context, service *rpservice.Service) error
UpdateService(ctx context.Context, service *rpservice.Service) error
DeleteService(ctx context.Context, accountID, serviceID string) error
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
@@ -271,7 +276,17 @@ type Store interface {
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
}
const (

View File

@@ -12,9 +12,10 @@ import (
gomock "github.com/golang/mock/gomock"
dns "github.com/netbirdio/netbird/dns"
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
zones "github.com/netbirdio/netbird/management/internals/modules/zones"
records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
types "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
}
// CleanupStaleProxies mocks base method.
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
ret0, _ := ret[0].(error)
return ret0
}
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -193,6 +208,21 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain)
}
// CountEphemeralServicesByPeer mocks base method.
func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer.
func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
}
// CreateAccessLog mocks base method.
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
m.ctrl.T.Helper()
@@ -293,7 +323,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
}
// CreateService mocks base method.
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateService", ctx, service)
ret0, _ := ret[0].(error)
@@ -559,6 +589,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
}
// DeleteServiceTargets mocks base method.
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
}
// DeleteSetupKey mocks base method.
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
m.ctrl.T.Helper()
@@ -573,6 +617,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
}
// DeleteTarget mocks base method.
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteTarget indicates an expected call of DeleteTarget.
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
}
// DeleteTokenID2UserIDIndex mocks base method.
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
m.ctrl.T.Helper()
@@ -643,6 +701,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
}
// EphemeralServiceExists mocks base method.
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// EphemeralServiceExists indicates an expected call of EphemeralServiceExists.
func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
}
// ExecuteInTransaction mocks base method.
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
m.ctrl.T.Helper()
@@ -1095,10 +1168,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
}
// GetAccountServices mocks base method.
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1109,21 +1182,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetAccountSettings mocks base method.
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
m.ctrl.T.Helper()
@@ -1214,6 +1272,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
}
// GetActiveProxyClusterAddresses mocks base method.
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
}
// GetAllAccounts mocks base method.
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
m.ctrl.T.Helper()
@@ -1288,6 +1361,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
}
// GetCustomDomainsCounts mocks base method.
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(int64)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
}
// GetDNSRecordByID mocks base method.
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
m.ctrl.T.Helper()
@@ -1303,6 +1392,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
}
// GetExpiredEphemeralServices mocks base method.
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices.
func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit)
}
// GetGroupByID mocks base method.
func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) {
m.ctrl.T.Helper()
@@ -1828,10 +1932,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
ret0, _ := ret[0].(*reverseproxy.Service)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1843,10 +1947,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
}
// GetServiceByID mocks base method.
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].(*reverseproxy.Service)
ret0, _ := ret[0].(*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1858,10 +1962,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
}
// GetServiceTargetByTargetID mocks base method.
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) {
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
ret0, _ := ret[0].(*reverseproxy.Target)
ret0, _ := ret[0].(*service.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1873,10 +1977,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
}
// GetServices mocks base method.
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret0, _ := ret[0].([]*service.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
@@ -1946,6 +2050,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
}
// GetTargetsByServiceID mocks base method.
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
ret0, _ := ret[0].([]*service.Target)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
}
// GetTokenIDByHashedToken mocks base method.
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
m.ctrl.T.Helper()
@@ -2327,6 +2446,20 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID)
}
// RenewEphemeralService mocks base method.
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
}
// RevokeProxyAccessToken mocks base method.
func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
m.ctrl.T.Helper()
@@ -2551,6 +2684,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
}
// SaveProxy mocks base method.
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
ret0, _ := ret[0].(error)
return ret0
}
// SaveProxy indicates an expected call of SaveProxy.
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
}
// SaveProxyAccessToken mocks base method.
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
m.ctrl.T.Helper()
@@ -2746,8 +2893,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
}
// UpdateProxyHeartbeat mocks base method.
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
}
// UpdateService mocks base method.
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateService", ctx, service)
ret0, _ := ret[0].(error)

View File

@@ -0,0 +1,185 @@
package telemetry
import (
"context"
"math"
"sync"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
)
// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95
// without publishing individual account labels
type AccountDurationAggregator struct {
mu sync.RWMutex
accounts map[string]*accountHistogram
meterProvider *sdkmetric.MeterProvider
manualReader *sdkmetric.ManualReader
FlushInterval time.Duration
MaxAge time.Duration
ctx context.Context
}
type accountHistogram struct {
histogram metric.Int64Histogram
lastUpdate time.Time
}
// NewAccountDurationAggregator creates aggregator using OTel histograms
func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator {
manualReader := sdkmetric.NewManualReader(
sdkmetric.WithTemporalitySelector(func(kind sdkmetric.InstrumentKind) metricdata.Temporality {
return metricdata.DeltaTemporality
}),
)
meterProvider := sdkmetric.NewMeterProvider(
sdkmetric.WithReader(manualReader),
)
return &AccountDurationAggregator{
accounts: make(map[string]*accountHistogram),
meterProvider: meterProvider,
manualReader: manualReader,
FlushInterval: flushInterval,
MaxAge: maxAge,
ctx: ctx,
}
}
// Record adds a duration for an account using OTel histogram
func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) {
a.mu.Lock()
defer a.mu.Unlock()
accHist, exists := a.accounts[accountID]
if !exists {
meter := a.meterProvider.Meter("account-aggregator")
histogram, err := meter.Int64Histogram(
"sync_duration_per_account",
metric.WithUnit("milliseconds"),
)
if err != nil {
return
}
accHist = &accountHistogram{
histogram: histogram,
}
a.accounts[accountID] = accHist
}
accHist.histogram.Record(a.ctx, duration.Milliseconds(),
metric.WithAttributes(attribute.String("account_id", accountID)))
accHist.lastUpdate = time.Now()
}
// FlushAndGetP95s extracts P95 from each account's histogram
func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 {
a.mu.Lock()
defer a.mu.Unlock()
var rm metricdata.ResourceMetrics
err := a.manualReader.Collect(a.ctx, &rm)
if err != nil {
return nil
}
now := time.Now()
p95s := make([]int64, 0, len(a.accounts))
for _, scopeMetrics := range rm.ScopeMetrics {
for _, metric := range scopeMetrics.Metrics {
histogramData, ok := metric.Data.(metricdata.Histogram[int64])
if !ok {
continue
}
for _, dataPoint := range histogramData.DataPoints {
a.processDataPoint(dataPoint, now, &p95s)
}
}
}
a.cleanupStaleAccounts(now)
return p95s
}
// processDataPoint extracts P95 from a single histogram data point
func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) {
accountID := extractAccountID(dataPoint)
if accountID == "" {
return
}
if p95 := calculateP95FromHistogram(dataPoint); p95 > 0 {
*p95s = append(*p95s, p95)
}
}
// cleanupStaleAccounts removes accounts that haven't been updated recently
func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) {
for accountID := range a.accounts {
if a.isStaleAccount(accountID, now) {
delete(a.accounts, accountID)
}
}
}
// extractAccountID retrieves the account_id from histogram data point attributes
func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string {
for _, attr := range dp.Attributes.ToSlice() {
if attr.Key == "account_id" {
return attr.Value.AsString()
}
}
return ""
}
// isStaleAccount checks if an account hasn't been updated recently
func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool {
accHist, exists := a.accounts[accountID]
if !exists {
return false
}
return now.Sub(accHist.lastUpdate) > a.MaxAge
}
// calculateP95FromHistogram computes P95 from OTel histogram data
func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 {
if dp.Count == 0 {
return 0
}
targetCount := uint64(math.Ceil(float64(dp.Count) * 0.95))
if targetCount == 0 {
targetCount = 1
}
var cumulativeCount uint64
for i, bucketCount := range dp.BucketCounts {
cumulativeCount += bucketCount
if cumulativeCount >= targetCount {
if i < len(dp.Bounds) {
return int64(dp.Bounds[i])
}
if maxVal, defined := dp.Max.Value(); defined {
return maxVal
}
return dp.Sum / int64(dp.Count)
}
}
return dp.Sum / int64(dp.Count)
}
// Shutdown cleans up resources
func (a *AccountDurationAggregator) Shutdown() error {
return a.meterProvider.Shutdown(a.ctx)
}

View File

@@ -0,0 +1,219 @@
package telemetry
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) {
// Verify that with delta temporality, each flush window only reflects
// recordings since the last flush — not all-time data.
ctx := context.Background()
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
// Window 1: Record 100 slow requests (500ms each)
for range 100 {
agg.Record("account-A", 500*time.Millisecond)
}
p95sWindow1 := agg.FlushAndGetP95s()
require.Len(t, p95sWindow1, 1, "should have P95 for one account")
firstP95 := p95sWindow1[0]
assert.GreaterOrEqual(t, firstP95, int64(200),
"first window P95 should reflect the 500ms recordings")
// Window 2: Record 100 FAST requests (10ms each)
for range 100 {
agg.Record("account-A", 10*time.Millisecond)
}
p95sWindow2 := agg.FlushAndGetP95s()
require.Len(t, p95sWindow2, 1, "should have P95 for one account")
secondP95 := p95sWindow2[0]
// With delta temporality the P95 should drop significantly because
// the first window's slow recordings are no longer included.
assert.Less(t, secondP95, firstP95,
"second window P95 should be lower than first — delta temporality "+
"ensures each window only reflects recent recordings")
}
func TestEqualWeightPerAccount(t *testing.T) {
// Verify that each account contributes exactly one P95 value,
// regardless of how many requests it made.
ctx := context.Background()
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
// Account A: 10,000 requests at 500ms (noisy customer)
for range 10000 {
agg.Record("account-A", 500*time.Millisecond)
}
// Accounts B, C, D: 10 requests each at 50ms (normal customers)
for _, id := range []string{"account-B", "account-C", "account-D"} {
for range 10 {
agg.Record(id, 50*time.Millisecond)
}
}
p95s := agg.FlushAndGetP95s()
// Should get exactly 4 P95 values — one per account
assert.Len(t, p95s, 4, "each account should contribute exactly one P95")
}
func TestStaleAccountEviction(t *testing.T) {
ctx := context.Background()
// Use a very short MaxAge so we can test staleness
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
agg.Record("account-A", 100*time.Millisecond)
agg.Record("account-B", 200*time.Millisecond)
// Both accounts should appear
p95s := agg.FlushAndGetP95s()
assert.Len(t, p95s, 2, "both accounts should have P95 values")
// Wait for account-A to become stale, then only update account-B
time.Sleep(60 * time.Millisecond)
agg.Record("account-B", 200*time.Millisecond)
p95s = agg.FlushAndGetP95s()
assert.Len(t, p95s, 1, "both accounts should have P95 values")
// account-A should have been evicted from the accounts map
agg.mu.RLock()
_, accountAExists := agg.accounts["account-A"]
_, accountBExists := agg.accounts["account-B"]
agg.mu.RUnlock()
assert.False(t, accountAExists, "stale account-A should be evicted from map")
assert.True(t, accountBExists, "active account-B should remain in map")
}
func TestStaleAccountEviction_DoesNotReappear(t *testing.T) {
// Verify that with delta temporality, an evicted stale account does not
// reappear in subsequent flushes.
ctx := context.Background()
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
agg.Record("account-stale", 100*time.Millisecond)
// Wait for it to become stale
time.Sleep(60 * time.Millisecond)
// First flush: should detect staleness and evict
_ = agg.FlushAndGetP95s()
agg.mu.RLock()
_, exists := agg.accounts["account-stale"]
agg.mu.RUnlock()
assert.False(t, exists, "account should be evicted after first flush")
// Second flush: with delta temporality, the stale account should NOT reappear
p95sSecond := agg.FlushAndGetP95s()
assert.Empty(t, p95sSecond,
"evicted account should not reappear in subsequent flushes with delta temporality")
}
func TestP95Calculation_SingleSample(t *testing.T) {
ctx := context.Background()
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
agg.Record("account-A", 150*time.Millisecond)
p95s := agg.FlushAndGetP95s()
require.Len(t, p95s, 1)
// With a single sample, P95 should be the bucket bound containing 150ms
assert.Greater(t, p95s[0], int64(0), "P95 of a single sample should be positive")
}
func TestP95Calculation_AllSameValue(t *testing.T) {
ctx := context.Background()
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
// All samples are 100ms — P95 should be the bucket bound containing 100ms
for range 100 {
agg.Record("account-A", 100*time.Millisecond)
}
p95s := agg.FlushAndGetP95s()
require.Len(t, p95s, 1)
assert.Greater(t, p95s[0], int64(0))
}
func TestMultipleAccounts_IndependentP95s(t *testing.T) {
ctx := context.Background()
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
defer func(agg *AccountDurationAggregator) {
err := agg.Shutdown()
if err != nil {
t.Errorf("failed to shutdown aggregator: %v", err)
}
}(agg)
// Account A: all fast (10ms)
for range 100 {
agg.Record("account-fast", 10*time.Millisecond)
}
// Account B: all slow (5000ms)
for range 100 {
agg.Record("account-slow", 5000*time.Millisecond)
}
p95s := agg.FlushAndGetP95s()
require.Len(t, p95s, 2, "should have two P95 values")
// Find min and max — they should differ significantly
minP95 := p95s[0]
maxP95 := p95s[1]
if minP95 > maxP95 {
minP95, maxP95 = maxP95, minP95
}
assert.Less(t, minP95, int64(1000),
"fast account P95 should be well under 1000ms")
assert.Greater(t, maxP95, int64(1000),
"slow account P95 should be well over 1000ms")
}

View File

@@ -13,18 +13,24 @@ const HighLatencyThreshold = time.Second * 7
// GRPCMetrics are gRPC server metrics
type GRPCMetrics struct {
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
meter metric.Meter
syncRequestsCounter metric.Int64Counter
syncRequestsBlockedCounter metric.Int64Counter
loginRequestsCounter metric.Int64Counter
loginRequestsBlockedCounter metric.Int64Counter
loginRequestHighLatencyCounter metric.Int64Counter
getKeyRequestsCounter metric.Int64Counter
activeStreamsGauge metric.Int64ObservableGauge
syncRequestDuration metric.Int64Histogram
syncRequestDurationP95ByAccount metric.Int64Histogram
loginRequestDuration metric.Int64Histogram
loginRequestDurationP95ByAccount metric.Int64Histogram
channelQueueLength metric.Int64Histogram
ctx context.Context
// Per-account aggregation
syncDurationAggregator *AccountDurationAggregator
loginDurationAggregator *AccountDurationAggregator
}
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
@@ -93,6 +99,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"),
)
if err != nil {
return nil, err
}
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
@@ -101,6 +115,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms",
metric.WithUnit("milliseconds"),
metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"),
)
if err != nil {
return nil, err
}
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
// Then we should be able to extract min, manx, mean and the percentiles.
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
@@ -113,20 +135,32 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
return nil, err
}
return &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration,
loginRequestDuration: loginRequestDuration,
channelQueueLength: channelQueue,
ctx: ctx,
}, err
syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
grpcMetrics := &GRPCMetrics{
meter: meter,
syncRequestsCounter: syncRequestsCounter,
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
loginRequestsCounter: loginRequestsCounter,
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
getKeyRequestsCounter: getKeyRequestsCounter,
activeStreamsGauge: activeStreamsGauge,
syncRequestDuration: syncRequestDuration,
syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount,
loginRequestDuration: loginRequestDuration,
loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount,
channelQueueLength: channelQueue,
ctx: ctx,
syncDurationAggregator: syncDurationAggregator,
loginDurationAggregator: loginDurationAggregator,
}
go grpcMetrics.startSyncP95Flusher()
go grpcMetrics.startLoginP95Flusher()
return grpcMetrics, err
}
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
@@ -157,6 +191,9 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
// CountLoginRequestDuration counts the duration of the login gRPC requests
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
grpcMetrics.loginDurationAggregator.Record(accountID, duration)
if duration > HighLatencyThreshold {
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
}
@@ -165,6 +202,44 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
// CountSyncRequestDuration counts the duration of the sync gRPC requests
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
grpcMetrics.syncDurationAggregator.Record(accountID, duration)
}
// startSyncP95Flusher periodically flushes per-account sync P95 values to the histogram
func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() {
ticker := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval)
defer ticker.Stop()
for {
select {
case <-grpcMetrics.ctx.Done():
return
case <-ticker.C:
p95s := grpcMetrics.syncDurationAggregator.FlushAndGetP95s()
for _, p95 := range p95s {
grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
}
}
}
}
// startLoginP95Flusher periodically flushes per-account login P95 values to the histogram
func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() {
ticker := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval)
defer ticker.Stop()
for {
select {
case <-grpcMetrics.ctx.Done():
return
case <-ticker.C:
p95s := grpcMetrics.loginDurationAggregator.FlushAndGetP95s()
for _, p95 := range p95s {
grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
}
}
}
}
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.

View File

@@ -18,7 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -100,7 +100,7 @@ type Account struct {
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
networkResources = append(networkResources, resource.Copy())
}
services := []*reverseproxy.Service{}
services := []*service.Service{}
for _, service := range a.Services {
services = append(services, service.Copy())
}
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
}
}
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
for _, target := range service.Targets {
if !target.Enabled {
continue
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
}
}
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
port, ok := a.resolveTargetPort(ctx, target)
if !ok {
return
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
}
}
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
if target.Port != 0 {
return target.Port, true
}
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
}
}
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
return &Policy{
ID: policyID,

View File

@@ -47,6 +47,11 @@ type Settings struct {
// NetworkRange is the custom network range for that account
NetworkRange netip.Prefix `gorm:"serializer:json"`
// PeerExposeEnabled enables or disables peer-initiated service expose
PeerExposeEnabled bool
// PeerExposeGroups list of peer group IDs allowed to expose services
PeerExposeGroups []string `gorm:"serializer:json"`
// Extra is a dictionary of Account settings
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
@@ -80,6 +85,8 @@ func (s *Settings) Copy() *Settings {
PeerInactivityExpiration: s.PeerInactivityExpiration,
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
PeerExposeEnabled: s.PeerExposeEnabled,
PeerExposeGroups: slices.Clone(s.PeerExposeGroups),
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,

View File

@@ -742,6 +742,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
if err != nil {
return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err)
}
// Ensure the initiator still has admin privileges
if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() {
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
}
initiatorUser = freshInitiator
}
@@ -872,10 +877,6 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
return nil
}
if !initiatorUser.HasAdminPower() {
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
}
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}

View File

@@ -2032,27 +2032,6 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
})
}
func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) {
groupsMap := map[string]*types.Group{}
initiator := &types.User{
Id: "initiator",
Role: types.UserRoleUser,
}
oldUser := &types.User{
Id: "target",
Role: types.UserRoleUser,
}
update := &types.User{
Id: "target",
Role: types.UserRoleOwner,
}
err := validateUserUpdate(groupsMap, initiator, oldUser, update)
require.Error(t, err, "regular user should not be able to promote to owner")
assert.Contains(t, err.Error(), "only admins and owners can update users")
}
func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
@@ -2109,7 +2088,7 @@ func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
})
require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted")
assert.Contains(t, err.Error(), "only admins and owners can update users")
assert.Contains(t, err.Error(), "initiator role was changed during request processing")
targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID)
require.NoError(t, err)