mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-14 12:49:57 +00:00
Merge branch 'refs/heads/main' into refactor/permissions-manager
# Conflicts: # management/internals/modules/reverseproxy/service/manager/api.go # management/server/http/handler.go
This commit is contained in:
@@ -15,7 +15,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
@@ -83,9 +83,9 @@ type DefaultAccountManager struct {
|
||||
|
||||
requestBuffer *AccountRequestBuffer
|
||||
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
proxyController port_forwarding.Controller
|
||||
settingsManager settings.Manager
|
||||
serviceManager service.Manager
|
||||
|
||||
// config contains the management server configuration
|
||||
config *nbconfig.Config
|
||||
@@ -115,8 +115,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
var _ account.Manager = (*DefaultAccountManager)(nil)
|
||||
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
am.reverseProxyManager = serviceManager
|
||||
func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
am.serviceManager = serviceManager
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -376,6 +376,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -394,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
|
||||
}
|
||||
if reloadReverseProxy {
|
||||
if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
@@ -492,6 +493,21 @@ func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Con
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerExposeSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
oldEnabled := oldSettings.PeerExposeEnabled
|
||||
newEnabled := newSettings.PeerExposeEnabled
|
||||
|
||||
if oldEnabled == newEnabled {
|
||||
return
|
||||
}
|
||||
|
||||
event := activity.AccountPeerExposeEnabled
|
||||
if !newEnabled {
|
||||
event = activity.AccountPeerExposeDisabled
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||
if newSettings.PeerInactivityExpirationEnabled {
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
@@ -714,7 +730,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
|
||||
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||
}
|
||||
@@ -1363,9 +1379,10 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
if am.singleAccountMode && am.singleAccountModeDomain != "" {
|
||||
// This section is mostly related to self-hosted installations.
|
||||
// We override incoming domain claims to group users under a single account.
|
||||
userAuth.Domain = am.singleAccountModeDomain
|
||||
userAuth.DomainCategory = types.PrivateCategory
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
err := am.updateUserAuthWithSingleMode(ctx, &userAuth)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
|
||||
accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, userAuth)
|
||||
@@ -1398,6 +1415,35 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
|
||||
return accountID, user.Id, nil
|
||||
}
|
||||
|
||||
// updateUserAuthWithSingleMode modifies the userAuth with the single account domain, or if there is an existing account, with the domain of that account
|
||||
func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Context, userAuth *auth.UserAuth) error {
|
||||
userAuth.DomainCategory = types.PrivateCategory
|
||||
userAuth.Domain = am.singleAccountModeDomain
|
||||
|
||||
accountID, err := am.Store.GetAnyAccountID(ctx)
|
||||
if err != nil {
|
||||
if e, ok := status.FromError(err); !ok || e.Type() != status.NotFound {
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
if accountID == "" {
|
||||
log.WithContext(ctx).Debugf("using singleAccountModeDomain to override JWT Domain and DomainCategory claims in single account mode")
|
||||
return nil
|
||||
}
|
||||
|
||||
domain, _, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userAuth.Domain = domain
|
||||
|
||||
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package account
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package account -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -61,11 +63,11 @@ type Manager interface {
|
||||
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
|
||||
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
|
||||
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
|
||||
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
|
||||
@@ -140,5 +142,5 @@ type Manager interface {
|
||||
CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error
|
||||
GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error)
|
||||
GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error)
|
||||
SetServiceManager(serviceManager reverseproxy.Manager)
|
||||
SetServiceManager(serviceManager service.Manager)
|
||||
}
|
||||
|
||||
1738
management/server/account/manager_mock.go
Normal file
1738
management/server/account/manager_mock.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -15,10 +15,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/prometheus/client_golang/prometheus/push"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -27,8 +29,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
@@ -1803,12 +1807,12 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
Services: []*reverseproxy.Service{
|
||||
Services: []*service.Service{
|
||||
{
|
||||
ID: "service1",
|
||||
Name: "test-service",
|
||||
AccountID: "account1",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
Targets: []*service.Target{},
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
@@ -3113,6 +3117,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
proxyManager := proxy.NewMockManager(ctrl)
|
||||
proxyManager.EXPECT().
|
||||
CleanupStale(gomock.Any(), gomock.Any()).
|
||||
Return(nil).
|
||||
AnyTimes()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
updateManager := update_channel.NewPeersUpdateManager(metrics)
|
||||
@@ -3123,8 +3133,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil)
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil))
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
@@ -3953,3 +3967,116 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
|
||||
t.Fatal("UpdateAccountSettings deadlocked when changing NetworkRange")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateUserAuthWithSingleMode(t *testing.T) {
|
||||
t.Run("sets defaults and overrides domain from store", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("account-1", nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||
Return("real-domain.com", "private", nil)
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "real-domain.com", userAuth.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("falls back to singleAccountModeDomain when account ID is empty", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("", nil)
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("falls back to singleAccountModeDomain on NotFound error", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("", status.Errorf(status.NotFound, "no accounts"))
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "fallback.com", userAuth.Domain)
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("propagates non-NotFound error from GetAnyAccountID", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("", status.Errorf(status.Internal, "db down"))
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "db down")
|
||||
// Defaults should still be set before error path
|
||||
assert.Equal(t, types.PrivateCategory, userAuth.DomainCategory)
|
||||
})
|
||||
|
||||
t.Run("propagates error from GetAccountDomainAndCategory", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetAnyAccountID(gomock.Any()).
|
||||
Return("account-1", nil)
|
||||
mockStore.EXPECT().
|
||||
GetAccountDomainAndCategory(gomock.Any(), store.LockingStrengthNone, "account-1").
|
||||
Return("", "", status.Errorf(status.Internal, "query failed"))
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: mockStore,
|
||||
singleAccountModeDomain: "fallback.com",
|
||||
}
|
||||
|
||||
userAuth := &auth.UserAuth{}
|
||||
err := am.updateUserAuthWithSingleMode(context.Background(), userAuth)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "query failed")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -208,6 +208,18 @@ const (
|
||||
ServiceUpdated Activity = 109
|
||||
ServiceDeleted Activity = 110
|
||||
|
||||
// PeerServiceExposed indicates that a peer exposed a service via the reverse proxy
|
||||
PeerServiceExposed Activity = 111
|
||||
// PeerServiceUnexposed indicates that a peer-exposed service was removed
|
||||
PeerServiceUnexposed Activity = 112
|
||||
// PeerServiceExposeExpired indicates that a peer-exposed service was removed due to TTL expiration
|
||||
PeerServiceExposeExpired Activity = 113
|
||||
|
||||
// AccountPeerExposeEnabled indicates that a user enabled peer expose for the account
|
||||
AccountPeerExposeEnabled Activity = 114
|
||||
// AccountPeerExposeDisabled indicates that a user disabled peer expose for the account
|
||||
AccountPeerExposeDisabled Activity = 115
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -345,6 +357,13 @@ var activityMap = map[Activity]Code{
|
||||
ServiceCreated: {"Service created", "service.create"},
|
||||
ServiceUpdated: {"Service updated", "service.update"},
|
||||
ServiceDeleted: {"Service deleted", "service.delete"},
|
||||
|
||||
PeerServiceExposed: {"Peer exposed service", "service.peer.expose"},
|
||||
PeerServiceUnexposed: {"Peer unexposed service", "service.peer.unexpose"},
|
||||
PeerServiceExposeExpired: {"Peer exposed service expired", "service.peer.expose.expire"},
|
||||
|
||||
AccountPeerExposeEnabled: {"Account peer expose enabled", "account.setting.peer.expose.enable"},
|
||||
AccountPeerExposeDisabled: {"Account peer expose disabled", "account.setting.peer.expose.disable"},
|
||||
}
|
||||
|
||||
// StringCode returns a string code of the activity
|
||||
|
||||
@@ -249,7 +249,15 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
|
||||
|
||||
switch storeEngine {
|
||||
case types.SqliteStoreEngine:
|
||||
dialector = sqlite.Open(filepath.Join(dataDir, eventSinkDB))
|
||||
dbFile := eventSinkDB
|
||||
if envFile, ok := os.LookupEnv("NB_ACTIVITY_EVENT_SQLITE_FILE"); ok && envFile != "" {
|
||||
dbFile = envFile
|
||||
}
|
||||
connStr := dbFile
|
||||
if !filepath.IsAbs(dbFile) {
|
||||
connStr = filepath.Join(dataDir, dbFile)
|
||||
}
|
||||
dialector = sqlite.Open(connStr)
|
||||
case types.PostgresStoreEngine:
|
||||
dsn, ok := os.LookupEnv(postgresDsnEnv)
|
||||
if !ok {
|
||||
|
||||
@@ -425,6 +425,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
var groupIDsToDelete []string
|
||||
var deletedGroups []*types.Group
|
||||
|
||||
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
|
||||
@@ -433,7 +438,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
|
||||
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
|
||||
allErrors = errors.Join(allErrors, err)
|
||||
continue
|
||||
}
|
||||
@@ -621,7 +626,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
|
||||
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == types.GroupIssuedIntegration {
|
||||
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
@@ -641,6 +646,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
||||
return &GroupLinkError{"network resource", group.Resources[0].ID}
|
||||
}
|
||||
|
||||
if slices.Contains(flowGroups, group.ID) {
|
||||
return &GroupLinkError{"settings", "traffic event logging"}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
peer2 "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -284,6 +286,67 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_DeleteGroupLinkedToFlowGroup(t *testing.T) {
|
||||
am, _, err := createManager(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
settingsMock := settings.NewMockManager(ctrl)
|
||||
settingsMock.EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{FlowGroups: []string{"grp-for-flow"}}, nil).
|
||||
AnyTimes()
|
||||
settingsMock.EXPECT().
|
||||
UpdateExtraSettings(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
Return(false, nil).
|
||||
AnyTimes()
|
||||
am.settingsManager = settingsMock
|
||||
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
require.NoError(t, err)
|
||||
|
||||
grp := &types.Group{
|
||||
ID: "grp-for-flow",
|
||||
AccountID: account.Id,
|
||||
Name: "Group for flow",
|
||||
Issued: types.GroupIssuedAPI,
|
||||
Peers: make([]string, 0),
|
||||
}
|
||||
require.NoError(t, am.CreateGroup(context.Background(), account.Id, groupAdminUserID, grp))
|
||||
|
||||
err = am.DeleteGroup(context.Background(), account.Id, groupAdminUserID, "grp-for-flow")
|
||||
require.Error(t, err)
|
||||
|
||||
var gErr *GroupLinkError
|
||||
require.ErrorAs(t, err, &gErr)
|
||||
assert.Equal(t, "settings", gErr.Resource)
|
||||
assert.Equal(t, "traffic event logging", gErr.Name)
|
||||
|
||||
group, err := am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, group)
|
||||
|
||||
regularGrp := &types.Group{
|
||||
ID: "grp-regular",
|
||||
AccountID: account.Id,
|
||||
Name: "Regular group",
|
||||
Issued: types.GroupIssuedAPI,
|
||||
Peers: make([]string, 0),
|
||||
}
|
||||
err = am.CreateGroup(context.Background(), account.Id, groupAdminUserID, regularGrp)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, []string{"grp-for-flow", "grp-regular"})
|
||||
require.Error(t, err)
|
||||
|
||||
group, err = am.GetGroup(context.Background(), account.Id, "grp-for-flow", groupAdminUserID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, group)
|
||||
|
||||
_, err = am.GetGroup(context.Background(), account.Id, "grp-regular", groupAdminUserID)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) {
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
@@ -703,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
|
||||
t.Run("saving group linked to network router", func(t *testing.T) {
|
||||
permissionsManager := permissions.NewManager(manager.Store)
|
||||
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager)
|
||||
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
|
||||
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
|
||||
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -73,7 +73,7 @@ const (
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
|
||||
|
||||
// Register bypass paths for unauthenticated endpoints
|
||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||
@@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
idp.AddEndpoints(accountManager, router, permissionsManager)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
|
||||
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
if serviceManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
|
||||
@@ -163,6 +163,10 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAut
|
||||
}
|
||||
|
||||
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
|
||||
if req.Settings.PeerExposeEnabled && len(req.Settings.PeerExposeGroups) == 0 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer expose requires at least one group")
|
||||
}
|
||||
|
||||
returnSettings := &types.Settings{
|
||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||
@@ -170,6 +174,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
||||
|
||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||
|
||||
PeerExposeEnabled: req.Settings.PeerExposeEnabled,
|
||||
PeerExposeGroups: req.Settings.PeerExposeGroups,
|
||||
}
|
||||
|
||||
if req.Settings.Extra != nil {
|
||||
@@ -317,6 +324,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
JwtAllowGroups: &jwtAllowGroups,
|
||||
RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
|
||||
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
||||
PeerExposeEnabled: settings.PeerExposeEnabled,
|
||||
PeerExposeGroups: settings.PeerExposeGroups,
|
||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||
DnsDomain: &settings.DNSDomain,
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
|
||||
@@ -18,8 +18,8 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
|
||||
oidcServer := newFakeOIDCServer()
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute)
|
||||
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
|
||||
require.NoError(t, err)
|
||||
|
||||
usersManager := users.NewManager(testStore)
|
||||
|
||||
@@ -208,9 +209,10 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetProxyManager(&testServiceManager{store: testStore})
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
|
||||
handler := NewAuthCallbackHandler(proxyService, nil)
|
||||
|
||||
@@ -239,12 +241,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
pubKey := base64.StdEncoding.EncodeToString(pub)
|
||||
privKey := base64.StdEncoding.EncodeToString(priv)
|
||||
|
||||
testProxy := &reverseproxy.Service{
|
||||
testProxy := &service.Service{
|
||||
ID: "testProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Test Proxy",
|
||||
Domain: "test-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
@@ -254,8 +256,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"allowedGroupId"},
|
||||
},
|
||||
@@ -265,12 +267,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, testProxy))
|
||||
|
||||
restrictedProxy := &reverseproxy.Service{
|
||||
restrictedProxy := &service.Service{
|
||||
ID: "restrictedProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "Restricted Proxy",
|
||||
Domain: "restricted-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
@@ -280,8 +282,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: []string{"restrictedGroupId"},
|
||||
},
|
||||
@@ -291,12 +293,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
}
|
||||
require.NoError(t, testStore.CreateService(ctx, restrictedProxy))
|
||||
|
||||
noAuthProxy := &reverseproxy.Service{
|
||||
noAuthProxy := &service.Service{
|
||||
ID: "noAuthProxyId",
|
||||
AccountID: "testAccountId",
|
||||
Name: "No Auth Proxy",
|
||||
Domain: "no-auth-proxy.example.com",
|
||||
Targets: []*reverseproxy.Target{{
|
||||
Targets: []*service.Target{{
|
||||
Path: strPtr("/"),
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
@@ -306,8 +308,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store
|
||||
Enabled: true,
|
||||
}},
|
||||
Enabled: true,
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
BearerAuth: &reverseproxy.BearerAuthConfig{
|
||||
Auth: service.AuthConfig{
|
||||
BearerAuth: &service.BearerAuthConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
@@ -361,19 +363,19 @@ func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, u
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -385,7 +387,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error {
|
||||
func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -397,15 +399,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
return m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) {
|
||||
return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID)
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
@@ -413,6 +415,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
|
||||
|
||||
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -9,10 +9,13 @@ import (
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
|
||||
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
|
||||
@@ -91,12 +94,24 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
}
|
||||
|
||||
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
|
||||
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager)
|
||||
domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager)
|
||||
reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, proxyServiceServer, domainManager)
|
||||
proxyServiceServer.SetProxyManager(reverseProxyManager)
|
||||
am.SetServiceManager(reverseProxyManager)
|
||||
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy token store: %v", err)
|
||||
}
|
||||
noopMeter := noop.NewMeterProvider().Meter("")
|
||||
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy controller: %v", err)
|
||||
}
|
||||
serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager)
|
||||
proxyServiceServer.SetServiceManager(serviceManager)
|
||||
am.SetServiceManager(serviceManager)
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
@@ -114,7 +129,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil)
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ type EmbeddedIdPConfig struct {
|
||||
|
||||
// EmbeddedStorageConfig holds storage configuration for the embedded IdP.
|
||||
type EmbeddedStorageConfig struct {
|
||||
// Type is the storage type (currently only "sqlite3" is supported)
|
||||
// Type is the storage type: "sqlite3" (default) or "postgres"
|
||||
Type string
|
||||
// Config contains type-specific configuration
|
||||
Config EmbeddedStorageTypeConfig
|
||||
@@ -62,6 +62,8 @@ type EmbeddedStorageConfig struct {
|
||||
type EmbeddedStorageTypeConfig struct {
|
||||
// File is the path to the SQLite database file (for sqlite3 type)
|
||||
File string
|
||||
// DSN is the connection string for postgres
|
||||
DSN string
|
||||
}
|
||||
|
||||
// OwnerConfig represents the initial owner/admin user for the embedded IdP.
|
||||
@@ -74,6 +76,22 @@ type OwnerConfig struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
// buildIdpStorageConfig builds the Dex storage config map based on the storage type.
|
||||
func buildIdpStorageConfig(storageType string, cfg EmbeddedStorageTypeConfig) (map[string]interface{}, error) {
|
||||
switch storageType {
|
||||
case "sqlite3":
|
||||
return map[string]interface{}{
|
||||
"file": cfg.File,
|
||||
}, nil
|
||||
case "postgres":
|
||||
return map[string]interface{}{
|
||||
"dsn": cfg.DSN,
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported IdP storage type: %s", storageType)
|
||||
}
|
||||
}
|
||||
|
||||
// ToYAMLConfig converts EmbeddedIdPConfig to dex.YAMLConfig.
|
||||
func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
if c.Issuer == "" {
|
||||
@@ -85,6 +103,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
if c.Storage.Type == "sqlite3" && c.Storage.Config.File == "" {
|
||||
return nil, fmt.Errorf("storage file is required for sqlite3")
|
||||
}
|
||||
if c.Storage.Type == "postgres" && c.Storage.Config.DSN == "" {
|
||||
return nil, fmt.Errorf("storage DSN is required for postgres")
|
||||
}
|
||||
|
||||
storageConfig, err := buildIdpStorageConfig(c.Storage.Type, c.Storage.Config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
|
||||
}
|
||||
|
||||
// Build CLI redirect URIs including the device callback (both relative and absolute)
|
||||
cliRedirectURIs := c.CLIRedirectURIs
|
||||
@@ -100,10 +126,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
|
||||
cfg := &dex.YAMLConfig{
|
||||
Issuer: c.Issuer,
|
||||
Storage: dex.Storage{
|
||||
Type: c.Storage.Type,
|
||||
Config: map[string]interface{}{
|
||||
"file": c.Storage.Config.File,
|
||||
},
|
||||
Type: c.Storage.Type,
|
||||
Config: storageConfig,
|
||||
},
|
||||
Web: dex.Web{
|
||||
AllowedOrigins: []string{"*"},
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -51,6 +52,7 @@ type properties map[string]interface{}
|
||||
type DataSource interface {
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetStoreEngine() types.Engine
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
}
|
||||
|
||||
// ConnManager peer connection manager that holds state for current active connections
|
||||
@@ -211,6 +213,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
localUsers int
|
||||
idpUsers int
|
||||
embeddedIdpTypes map[string]int
|
||||
services int
|
||||
servicesEnabled int
|
||||
servicesTargets int
|
||||
servicesStatusActive int
|
||||
servicesStatusPending int
|
||||
servicesStatusError int
|
||||
servicesTargetType map[string]int
|
||||
servicesAuthPassword int
|
||||
servicesAuthPin int
|
||||
servicesAuthOIDC int
|
||||
)
|
||||
start := time.Now()
|
||||
metricsProperties := make(properties)
|
||||
@@ -220,10 +232,13 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
rulesDirection = make(map[string]int)
|
||||
activeUsersLastDay = make(map[string]struct{})
|
||||
embeddedIdpTypes = make(map[string]int)
|
||||
servicesTargetType = make(map[string]int)
|
||||
uptime = time.Since(w.startupTime).Seconds()
|
||||
connections := w.connManager.GetAllConnectedPeers()
|
||||
version = nbversion.NetbirdVersion()
|
||||
|
||||
customDomains, customDomainsValidated, _ := w.dataSource.GetCustomDomainsCounts(ctx)
|
||||
|
||||
for _, account := range w.dataSource.GetAllAccounts(ctx) {
|
||||
accounts++
|
||||
|
||||
@@ -279,9 +294,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
localUsers++
|
||||
} else {
|
||||
idpUsers++
|
||||
idpType := extractIdpType(idpID)
|
||||
embeddedIdpTypes[idpType]++
|
||||
}
|
||||
idpType := extractIdpType(idpID)
|
||||
embeddedIdpTypes[idpType]++
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -335,6 +350,37 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
peerActiveVersions = append(peerActiveVersions, peer.Meta.WtVersion)
|
||||
}
|
||||
}
|
||||
|
||||
for _, service := range account.Services {
|
||||
services++
|
||||
if service.Enabled {
|
||||
servicesEnabled++
|
||||
}
|
||||
servicesTargets += len(service.Targets)
|
||||
|
||||
switch rpservice.Status(service.Meta.Status) {
|
||||
case rpservice.StatusActive:
|
||||
servicesStatusActive++
|
||||
case rpservice.StatusPending:
|
||||
servicesStatusPending++
|
||||
case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated:
|
||||
servicesStatusError++
|
||||
}
|
||||
|
||||
for _, target := range service.Targets {
|
||||
servicesTargetType[target.TargetType]++
|
||||
}
|
||||
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled {
|
||||
servicesAuthPassword++
|
||||
}
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled {
|
||||
servicesAuthPin++
|
||||
}
|
||||
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
|
||||
servicesAuthOIDC++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
minActivePeerVersion, maxActivePeerVersion := getMinMaxVersion(peerActiveVersions)
|
||||
@@ -375,6 +421,22 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["idp_users_count"] = idpUsers
|
||||
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
|
||||
|
||||
metricsProperties["services"] = services
|
||||
metricsProperties["services_enabled"] = servicesEnabled
|
||||
metricsProperties["services_targets"] = servicesTargets
|
||||
metricsProperties["services_status_active"] = servicesStatusActive
|
||||
metricsProperties["services_status_pending"] = servicesStatusPending
|
||||
metricsProperties["services_status_error"] = servicesStatusError
|
||||
metricsProperties["services_auth_password"] = servicesAuthPassword
|
||||
metricsProperties["services_auth_pin"] = servicesAuthPin
|
||||
metricsProperties["services_auth_oidc"] = servicesAuthOIDC
|
||||
metricsProperties["custom_domains"] = customDomains
|
||||
metricsProperties["custom_domains_validated"] = customDomainsValidated
|
||||
|
||||
for targetType, count := range servicesTargetType {
|
||||
metricsProperties["services_target_type_"+targetType] = count
|
||||
}
|
||||
|
||||
for idpType, count := range embeddedIdpTypes {
|
||||
metricsProperties["embedded_idp_users_"+idpType] = count
|
||||
}
|
||||
@@ -469,6 +531,9 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
|
||||
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
|
||||
// Returns the type prefix, or "oidc" if no known prefix is found.
|
||||
func extractIdpType(connectorID string) string {
|
||||
if connectorID == "local" {
|
||||
return "local"
|
||||
}
|
||||
idx := strings.LastIndex(connectorID, "-")
|
||||
if idx <= 0 {
|
||||
return "oidc"
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -28,6 +29,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
|
||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
localUserID := dex.EncodeDexUserID("10", "local")
|
||||
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
|
||||
oidcUserID := dex.EncodeDexUserID("30", "d6jvvp69kmnc73c9pl40")
|
||||
return []*types.Account{
|
||||
{
|
||||
Id: "1",
|
||||
@@ -115,6 +117,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
},
|
||||
},
|
||||
Services: []*rpservice.Service{
|
||||
{
|
||||
ID: "svc1",
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "peer"},
|
||||
{TargetType: "host"},
|
||||
},
|
||||
Auth: rpservice.AuthConfig{
|
||||
PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true},
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusActive)},
|
||||
},
|
||||
{
|
||||
ID: "svc2",
|
||||
Enabled: false,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: "domain"},
|
||||
},
|
||||
Auth: rpservice.AuthConfig{
|
||||
BearerAuth: &rpservice.BearerAuthConfig{Enabled: true},
|
||||
},
|
||||
Meta: rpservice.Meta{Status: string(rpservice.StatusPending)},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Id: "2",
|
||||
@@ -180,6 +207,13 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
oidcUserID: {
|
||||
Id: oidcUserID,
|
||||
IsServiceUser: false,
|
||||
PATs: map[string]*types.PersonalAccessToken{
|
||||
"1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{
|
||||
@@ -215,6 +249,11 @@ func (mockDatasource) GetStoreEngine() types.Engine {
|
||||
return types.FileStoreEngine
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts returns test custom domain counts.
|
||||
func (mockDatasource) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
|
||||
return 3, 2, nil
|
||||
}
|
||||
|
||||
// TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties
|
||||
func TestGenerateProperties(t *testing.T) {
|
||||
ds := mockDatasource{}
|
||||
@@ -247,14 +286,14 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["rules"] != 4 {
|
||||
t.Errorf("expected 4 rules, got %d", properties["rules"])
|
||||
}
|
||||
if properties["users"] != 2 {
|
||||
t.Errorf("expected 1 users, got %d", properties["users"])
|
||||
if properties["users"] != 3 {
|
||||
t.Errorf("expected 3 users, got %d", properties["users"])
|
||||
}
|
||||
if properties["setup_keys_usage"] != 2 {
|
||||
t.Errorf("expected 1 setup_keys_usage, got %d", properties["setup_keys_usage"])
|
||||
}
|
||||
if properties["pats"] != 4 {
|
||||
t.Errorf("expected 4 personal_access_tokens, got %d", properties["pats"])
|
||||
if properties["pats"] != 5 {
|
||||
t.Errorf("expected 5 personal_access_tokens, got %d", properties["pats"])
|
||||
}
|
||||
if properties["peers_ssh_enabled"] != 2 {
|
||||
t.Errorf("expected 2 peers_ssh_enabled, got %d", properties["peers_ssh_enabled"])
|
||||
@@ -338,14 +377,63 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["local_users_count"] != 1 {
|
||||
t.Errorf("expected 1 local_users_count, got %d", properties["local_users_count"])
|
||||
}
|
||||
if properties["idp_users_count"] != 1 {
|
||||
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
|
||||
if properties["idp_users_count"] != 2 {
|
||||
t.Errorf("expected 2 idp_users_count, got %d", properties["idp_users_count"])
|
||||
}
|
||||
if properties["embedded_idp_users_local"] != 1 {
|
||||
t.Errorf("expected 1 embedded_idp_users_local, got %v", properties["embedded_idp_users_local"])
|
||||
}
|
||||
if properties["embedded_idp_users_zitadel"] != 1 {
|
||||
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
|
||||
}
|
||||
if properties["embedded_idp_count"] != 1 {
|
||||
t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||
if properties["embedded_idp_users_oidc"] != 1 {
|
||||
t.Errorf("expected 1 embedded_idp_users_oidc, got %v", properties["embedded_idp_users_oidc"])
|
||||
}
|
||||
if properties["embedded_idp_count"] != 3 {
|
||||
t.Errorf("expected 3 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||
}
|
||||
|
||||
if properties["services"] != 2 {
|
||||
t.Errorf("expected 2 services, got %v", properties["services"])
|
||||
}
|
||||
if properties["services_enabled"] != 1 {
|
||||
t.Errorf("expected 1 services_enabled, got %v", properties["services_enabled"])
|
||||
}
|
||||
if properties["services_targets"] != 3 {
|
||||
t.Errorf("expected 3 services_targets, got %v", properties["services_targets"])
|
||||
}
|
||||
if properties["services_status_active"] != 1 {
|
||||
t.Errorf("expected 1 services_status_active, got %v", properties["services_status_active"])
|
||||
}
|
||||
if properties["services_status_pending"] != 1 {
|
||||
t.Errorf("expected 1 services_status_pending, got %v", properties["services_status_pending"])
|
||||
}
|
||||
if properties["services_status_error"] != 0 {
|
||||
t.Errorf("expected 0 services_status_error, got %v", properties["services_status_error"])
|
||||
}
|
||||
if properties["services_target_type_peer"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_peer, got %v", properties["services_target_type_peer"])
|
||||
}
|
||||
if properties["services_target_type_host"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_host, got %v", properties["services_target_type_host"])
|
||||
}
|
||||
if properties["services_target_type_domain"] != 1 {
|
||||
t.Errorf("expected 1 services_target_type_domain, got %v", properties["services_target_type_domain"])
|
||||
}
|
||||
if properties["services_auth_password"] != 1 {
|
||||
t.Errorf("expected 1 services_auth_password, got %v", properties["services_auth_password"])
|
||||
}
|
||||
if properties["services_auth_oidc"] != 1 {
|
||||
t.Errorf("expected 1 services_auth_oidc, got %v", properties["services_auth_oidc"])
|
||||
}
|
||||
if properties["services_auth_pin"] != 0 {
|
||||
t.Errorf("expected 0 services_auth_pin, got %v", properties["services_auth_pin"])
|
||||
}
|
||||
if properties["custom_domains"] != int64(3) {
|
||||
t.Errorf("expected 3 custom_domains, got %v", properties["custom_domains"])
|
||||
}
|
||||
if properties["custom_domains_validated"] != int64(2) {
|
||||
t.Errorf("expected 2 custom_domains_validated, got %v", properties["custom_domains_validated"])
|
||||
}
|
||||
}
|
||||
|
||||
@@ -362,7 +450,8 @@ func TestExtractIdpType(t *testing.T) {
|
||||
{"microsoft-abc123", "microsoft"},
|
||||
{"authentik-abc123", "authentik"},
|
||||
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
|
||||
{"local", "oidc"},
|
||||
{"local", "local"},
|
||||
{"d6jvvp69kmnc73c9pl40", "oidc"},
|
||||
{"", "oidc"},
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
@@ -148,7 +148,7 @@ type MockAccountManager struct {
|
||||
DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) {
|
||||
func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
@@ -407,7 +407,7 @@ func (am *MockAccountManager) AddPeer(
|
||||
|
||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
if am.GetGroupFunc != nil {
|
||||
if am.GetGroupByNameFunc != nil {
|
||||
return am.GetGroupByNameFunc(ctx, accountID, groupName)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
@@ -33,23 +33,23 @@ type Manager interface {
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
groupsManager groups.Manager
|
||||
accountManager account.Manager
|
||||
reverseProxyManager reverseproxy.Manager
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
groupsManager groups.Manager
|
||||
accountManager account.Manager
|
||||
serviceManager service.Manager
|
||||
}
|
||||
|
||||
type mockManager struct {
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager {
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
reverseProxyManager: reverseproxyManager,
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
groupsManager: groupsManager,
|
||||
accountManager: accountManager,
|
||||
serviceManager: reverseproxyManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
|
||||
// TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them
|
||||
go func() {
|
||||
err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||
err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err)
|
||||
}
|
||||
@@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||
serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||
require.NoError(t, err)
|
||||
@@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID)
|
||||
require.Error(t, err)
|
||||
@@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||
require.NoError(t, err)
|
||||
@@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID)
|
||||
require.Error(t, err)
|
||||
@@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.NoError(t, err)
|
||||
@@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.Error(t, err)
|
||||
@@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.NoError(t, err)
|
||||
@@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
createdResource, err := manager.CreateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.NoError(t, err)
|
||||
@@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
updatedResource, err := manager.UpdateResource(ctx, userID, resource)
|
||||
require.Error(t, err)
|
||||
@@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes()
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.NoError(t, err)
|
||||
@@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) {
|
||||
am := mock_server.MockAccountManager{}
|
||||
groupsManager := groups.NewManagerMock()
|
||||
ctrl := gomock.NewController(t)
|
||||
reverseProxyManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager)
|
||||
serviceManager := reverseproxy.NewMockManager(ctrl)
|
||||
manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager)
|
||||
|
||||
err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID)
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
var settings *types.Settings
|
||||
var eventsToStore []func()
|
||||
|
||||
serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if resource is used by service: %w", err)
|
||||
}
|
||||
|
||||
@@ -352,9 +352,10 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest)
|
||||
p.Name = a.Name
|
||||
p.Key = a.WgPubKey
|
||||
p.Meta = PeerSystemMeta{
|
||||
Hostname: a.Name,
|
||||
GoOS: "js",
|
||||
OS: "js",
|
||||
Hostname: a.Name,
|
||||
GoOS: "js",
|
||||
OS: "js",
|
||||
KernelVersion: "wasm",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -269,3 +269,8 @@ func (s *FileStore) GetStoreEngine() types.Engine {
|
||||
func (s *FileStore) SetFieldEncrypt(_ *crypt.FieldEncrypt) {
|
||||
// no-op: FileStore stores data in plaintext JSON; encryption is not supported
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts is a no-op for FileStore as it doesn't support custom domains.
|
||||
func (s *FileStore) GetCustomDomainsCounts(_ context.Context) (int64, int64, error) {
|
||||
return 0, 0, nil
|
||||
}
|
||||
|
||||
@@ -28,9 +28,10 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
|
||||
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
|
||||
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{},
|
||||
&types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{},
|
||||
&accesslogs.AccessLogEntry{}, &proxy.Proxy{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
|
||||
@@ -1007,6 +1008,18 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
||||
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
var total, validated int64
|
||||
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return total, validated, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
||||
var accounts []types.Account
|
||||
result := s.db.Find(&accounts)
|
||||
@@ -2063,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
return checks, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
@@ -2078,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) {
|
||||
var s reverseproxy.Service
|
||||
services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) {
|
||||
var s rpservice.Service
|
||||
var auth []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
@@ -2109,12 +2122,13 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
}
|
||||
|
||||
s.Meta = reverseproxy.ServiceMeta{}
|
||||
s.Meta = rpservice.Meta{}
|
||||
if createdAt.Valid {
|
||||
s.Meta.CreatedAt = createdAt.Time
|
||||
}
|
||||
if certIssuedAt.Valid {
|
||||
s.Meta.CertificateIssuedAt = certIssuedAt.Time
|
||||
t := certIssuedAt.Time
|
||||
s.Meta.CertificateIssuedAt = &t
|
||||
}
|
||||
if status.Valid {
|
||||
s.Meta.Status = status.String
|
||||
@@ -2129,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
s.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
s.Targets = []*reverseproxy.Target{}
|
||||
s.Targets = []*rpservice.Target{}
|
||||
return &s, nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -2141,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
}
|
||||
|
||||
serviceIDs := make([]string, len(services))
|
||||
serviceMap := make(map[string]*reverseproxy.Service)
|
||||
serviceMap := make(map[string]*rpservice.Service)
|
||||
for i, s := range services {
|
||||
serviceIDs[i] = s.ID
|
||||
serviceMap[s.ID] = s
|
||||
@@ -2152,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) {
|
||||
var t reverseproxy.Target
|
||||
targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) {
|
||||
var t rpservice.Target
|
||||
var path sql.NullString
|
||||
err := row.Scan(
|
||||
&t.ID,
|
||||
@@ -2715,14 +2729,28 @@ func (s *SqlStore) GetStoreEngine() types.Engine {
|
||||
|
||||
// NewSqliteStore creates a new SQLite store.
|
||||
func NewSqliteStore(ctx context.Context, dataDir string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) {
|
||||
storeStr := fmt.Sprintf("%s?cache=shared", storeSqliteFileName)
|
||||
if runtime.GOOS == "windows" {
|
||||
// Vo avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
storeStr = storeSqliteFileName
|
||||
storeFile := storeSqliteFileName
|
||||
if envFile, ok := os.LookupEnv("NB_STORE_ENGINE_SQLITE_FILE"); ok && envFile != "" {
|
||||
storeFile = envFile
|
||||
}
|
||||
|
||||
file := filepath.Join(dataDir, storeStr)
|
||||
db, err := gorm.Open(sqlite.Open(file), getGormConfig())
|
||||
// Separate file path from any SQLite URI query parameters (e.g., "store.db?mode=rwc")
|
||||
filePath, query, hasQuery := strings.Cut(storeFile, "?")
|
||||
|
||||
connStr := filePath
|
||||
if !filepath.IsAbs(filePath) {
|
||||
connStr = filepath.Join(dataDir, filePath)
|
||||
}
|
||||
|
||||
// Append query parameters: user-provided take precedence, otherwise default to cache=shared on non-Windows
|
||||
if hasQuery {
|
||||
connStr += "?" + query
|
||||
} else if runtime.GOOS != "windows" {
|
||||
// To avoid `The process cannot access the file because it is being used by another process` on Windows
|
||||
connStr += "?cache=shared"
|
||||
}
|
||||
|
||||
db, err := gorm.Open(sqlite.Open(connStr), getGormConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -4825,7 +4853,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren
|
||||
return peerID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
@@ -4839,16 +4867,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error {
|
||||
serviceCopy := service.Copy()
|
||||
if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return fmt.Errorf("encrypt service data: %w", err)
|
||||
}
|
||||
|
||||
// Create target type instance outside transaction to avoid variable shadowing
|
||||
targetType := &rpservice.Target{}
|
||||
|
||||
// Use a transaction to ensure atomic updates of the service and its targets
|
||||
err := s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// Delete existing targets
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil {
|
||||
if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4869,7 +4900,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error {
|
||||
result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete service from store")
|
||||
@@ -4882,13 +4913,53 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error {
|
||||
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete target from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "target not found for service %s", serviceID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error {
|
||||
result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete targets from store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID retrieves all targets for a given service
|
||||
func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) {
|
||||
var targets []*rpservice.Target
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
result := tx.Where("account_id = ? AND service_id = ?", accountID, serviceID).Find(&targets)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get targets from store: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get targets from store")
|
||||
}
|
||||
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
var service *rpservice.Service
|
||||
result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -4906,30 +4977,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
||||
}
|
||||
|
||||
for _, service := range serviceList {
|
||||
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
var service *reverseproxy.Service
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) {
|
||||
var service *rpservice.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -4947,13 +4996,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -4969,13 +5018,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
var serviceList []*rpservice.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
@@ -4991,6 +5040,99 @@ func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingS
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
// RenewEphemeralService updates the last_renewed_at timestamp for an ephemeral service.
|
||||
func (s *SqlStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
||||
Update("meta_last_renewed_at", time.Now())
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to renew ephemeral service: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "renew ephemeral service")
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices returns ephemeral services whose last renewal exceeds the given TTL.
|
||||
// Only the fields needed for reaping are selected. The limit parameter caps the batch size to
|
||||
// avoid loading too many rows in a single tick. Rows with empty source_peer are excluded to
|
||||
// skip malformed legacy data.
|
||||
func (s *SqlStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error) {
|
||||
cutoff := time.Now().Add(-ttl)
|
||||
var services []*rpservice.Service
|
||||
result := s.db.
|
||||
Select("id", "account_id", "source_peer", "domain").
|
||||
Where("source = ? AND source_peer <> '' AND meta_last_renewed_at < ?", rpservice.SourceEphemeral, cutoff).
|
||||
Limit(limit).
|
||||
Find(&services)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get expired ephemeral services: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "get expired ephemeral services")
|
||||
}
|
||||
return services, nil
|
||||
}
|
||||
|
||||
// CountEphemeralServicesByPeer returns the count of ephemeral services for a specific peer.
|
||||
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
|
||||
// The locking is applied via a row-level SELECT ... FOR UPDATE (not on the aggregate) to
|
||||
// stay compatible with Postgres, which disallows FOR UPDATE on COUNT(*).
|
||||
func (s *SqlStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
|
||||
if lockStrength == LockingStrengthNone {
|
||||
var count int64
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
|
||||
return 0, status.Errorf(status.Internal, "count ephemeral services")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
var ids []string
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND source_peer = ? AND source = ?", accountID, peerID, rpservice.SourceEphemeral).
|
||||
Pluck("id", &ids)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to count ephemeral services: %v", result.Error)
|
||||
return 0, status.Errorf(status.Internal, "count ephemeral services")
|
||||
}
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
// EphemeralServiceExists checks if an ephemeral service exists for the given peer and domain.
|
||||
// Use LockingStrengthUpdate inside a transaction to serialize concurrent create operations.
|
||||
func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||
if lockStrength == LockingStrengthNone {
|
||||
var count int64
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
||||
Count(&count)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
|
||||
return false, status.Errorf(status.Internal, "check ephemeral service existence")
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
var id string
|
||||
result := s.db.Model(&rpservice.Service{}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND source_peer = ? AND domain = ? AND source = ?", accountID, peerID, domain, rpservice.SourceEphemeral).
|
||||
Limit(1).
|
||||
Pluck("id", &id)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to check ephemeral service existence: %v", result.Error)
|
||||
return false, status.Errorf(status.Internal, "check ephemeral service existence")
|
||||
}
|
||||
return id != "", nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) {
|
||||
tx := s.db
|
||||
|
||||
@@ -5203,13 +5345,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces
|
||||
return query
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) {
|
||||
func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var target *reverseproxy.Target
|
||||
var target *rpservice.Target
|
||||
result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -5222,3 +5364,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
|
||||
|
||||
return target, nil
|
||||
}
|
||||
|
||||
// SaveProxy saves or updates a proxy in the database
|
||||
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Update("last_seen", time.Now())
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to update proxy heartbeat")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)).
|
||||
Distinct("cluster_address").
|
||||
Pluck("cluster_address", &addresses)
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses")
|
||||
}
|
||||
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
Where("last_seen < ?", cutoffTime).
|
||||
Delete(&proxy.Proxy{})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to cleanup stale proxies")
|
||||
}
|
||||
|
||||
if result.RowsAffected > 0 {
|
||||
log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
||||
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
|
||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
&types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{},
|
||||
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
|
||||
}
|
||||
|
||||
for i := len(models) - 1; i >= 0; i-- {
|
||||
|
||||
@@ -25,9 +25,10 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
@@ -252,14 +253,18 @@ type Store interface {
|
||||
MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
|
||||
GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error)
|
||||
|
||||
CreateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
UpdateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
CreateService(ctx context.Context, service *rpservice.Service) error
|
||||
UpdateService(ctx context.Context, service *rpservice.Service) error
|
||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
|
||||
GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error)
|
||||
|
||||
RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error
|
||||
GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*rpservice.Service, error)
|
||||
CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error)
|
||||
EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error)
|
||||
|
||||
GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error)
|
||||
ListFreeDomains(ctx context.Context, accountID string) ([]string, error)
|
||||
@@ -271,7 +276,17 @@ type Store interface {
|
||||
CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error
|
||||
GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error)
|
||||
DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error)
|
||||
GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error)
|
||||
GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error)
|
||||
DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error
|
||||
DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error
|
||||
|
||||
SaveProxy(ctx context.Context, proxy *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -12,9 +12,10 @@ import (
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
dns "github.com/netbirdio/netbird/dns"
|
||||
reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
zones "github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
records "github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
types "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{})
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// CleanupStaleProxies mocks base method.
|
||||
func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupStaleProxies indicates an expected call of CleanupStaleProxies.
|
||||
func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -193,6 +208,21 @@ func (mr *MockStoreMockRecorder) CountAccountsByPrivateDomain(ctx, domain interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAccountsByPrivateDomain", reflect.TypeOf((*MockStore)(nil).CountAccountsByPrivateDomain), ctx, domain)
|
||||
}
|
||||
|
||||
// CountEphemeralServicesByPeer mocks base method.
|
||||
func (m *MockStore) CountEphemeralServicesByPeer(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountEphemeralServicesByPeer", ctx, lockStrength, accountID, peerID)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountEphemeralServicesByPeer indicates an expected call of CountEphemeralServicesByPeer.
|
||||
func (mr *MockStoreMockRecorder) CountEphemeralServicesByPeer(ctx, lockStrength, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountEphemeralServicesByPeer", reflect.TypeOf((*MockStore)(nil).CountEphemeralServicesByPeer), ctx, lockStrength, accountID, peerID)
|
||||
}
|
||||
|
||||
// CreateAccessLog mocks base method.
|
||||
func (m *MockStore) CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -293,7 +323,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C
|
||||
}
|
||||
|
||||
// CreateService mocks base method.
|
||||
func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateService", ctx, service)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -559,6 +589,20 @@ func (mr *MockStoreMockRecorder) DeleteService(ctx, accountID, serviceID interfa
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockStore)(nil).DeleteService), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// DeleteServiceTargets mocks base method.
|
||||
func (m *MockStore) DeleteServiceTargets(ctx context.Context, accountID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteServiceTargets", ctx, accountID, serviceID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteServiceTargets indicates an expected call of DeleteServiceTargets.
|
||||
func (mr *MockStoreMockRecorder) DeleteServiceTargets(ctx, accountID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceTargets", reflect.TypeOf((*MockStore)(nil).DeleteServiceTargets), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// DeleteSetupKey mocks base method.
|
||||
func (m *MockStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -573,6 +617,20 @@ func (mr *MockStoreMockRecorder) DeleteSetupKey(ctx, accountID, keyID interface{
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSetupKey", reflect.TypeOf((*MockStore)(nil).DeleteSetupKey), ctx, accountID, keyID)
|
||||
}
|
||||
|
||||
// DeleteTarget mocks base method.
|
||||
func (m *MockStore) DeleteTarget(ctx context.Context, accountID, serviceID string, targetID uint) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteTarget", ctx, accountID, serviceID, targetID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteTarget indicates an expected call of DeleteTarget.
|
||||
func (mr *MockStoreMockRecorder) DeleteTarget(ctx, accountID, serviceID, targetID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTarget", reflect.TypeOf((*MockStore)(nil).DeleteTarget), ctx, accountID, serviceID, targetID)
|
||||
}
|
||||
|
||||
// DeleteTokenID2UserIDIndex mocks base method.
|
||||
func (m *MockStore) DeleteTokenID2UserIDIndex(tokenID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -643,6 +701,21 @@ func (mr *MockStoreMockRecorder) DeleteZoneDNSRecords(ctx, accountID, zoneID int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteZoneDNSRecords", reflect.TypeOf((*MockStore)(nil).DeleteZoneDNSRecords), ctx, accountID, zoneID)
|
||||
}
|
||||
|
||||
// EphemeralServiceExists mocks base method.
|
||||
func (m *MockStore) EphemeralServiceExists(ctx context.Context, lockStrength LockingStrength, accountID, peerID, domain string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "EphemeralServiceExists", ctx, lockStrength, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// EphemeralServiceExists indicates an expected call of EphemeralServiceExists.
|
||||
func (mr *MockStoreMockRecorder) EphemeralServiceExists(ctx, lockStrength, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EphemeralServiceExists", reflect.TypeOf((*MockStore)(nil).EphemeralServiceExists), ctx, lockStrength, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// ExecuteInTransaction mocks base method.
|
||||
func (m *MockStore) ExecuteInTransaction(ctx context.Context, f func(Store) error) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1095,10 +1168,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i
|
||||
}
|
||||
|
||||
// GetAccountServices mocks base method.
|
||||
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1109,21 +1182,6 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSettings mocks base method.
|
||||
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1214,6 +1272,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx)
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses mocks base method.
|
||||
func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses.
|
||||
func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// GetAllAccounts mocks base method.
|
||||
func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1288,6 +1361,22 @@ func (mr *MockStoreMockRecorder) GetCustomDomain(ctx, accountID, domainID interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomain", reflect.TypeOf((*MockStore)(nil).GetCustomDomain), ctx, accountID, domainID)
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts mocks base method.
|
||||
func (m *MockStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCustomDomainsCounts", ctx)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(int64)
|
||||
ret2, _ := ret[2].(error)
|
||||
return ret0, ret1, ret2
|
||||
}
|
||||
|
||||
// GetCustomDomainsCounts indicates an expected call of GetCustomDomainsCounts.
|
||||
func (mr *MockStoreMockRecorder) GetCustomDomainsCounts(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCustomDomainsCounts", reflect.TypeOf((*MockStore)(nil).GetCustomDomainsCounts), ctx)
|
||||
}
|
||||
|
||||
// GetDNSRecordByID mocks base method.
|
||||
func (m *MockStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1303,6 +1392,21 @@ func (mr *MockStoreMockRecorder) GetDNSRecordByID(ctx, lockStrength, accountID,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDNSRecordByID", reflect.TypeOf((*MockStore)(nil).GetDNSRecordByID), ctx, lockStrength, accountID, zoneID, recordID)
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices mocks base method.
|
||||
func (m *MockStore) GetExpiredEphemeralServices(ctx context.Context, ttl time.Duration, limit int) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetExpiredEphemeralServices", ctx, ttl, limit)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetExpiredEphemeralServices indicates an expected call of GetExpiredEphemeralServices.
|
||||
func (mr *MockStoreMockRecorder) GetExpiredEphemeralServices(ctx, ttl, limit interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetExpiredEphemeralServices", reflect.TypeOf((*MockStore)(nil).GetExpiredEphemeralServices), ctx, ttl, limit)
|
||||
}
|
||||
|
||||
// GetGroupByID mocks base method.
|
||||
func (m *MockStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1828,10 +1932,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain)
|
||||
ret0, _ := ret[0].(*reverseproxy.Service)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1843,10 +1947,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter
|
||||
}
|
||||
|
||||
// GetServiceByID mocks base method.
|
||||
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].(*reverseproxy.Service)
|
||||
ret0, _ := ret[0].(*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1858,10 +1962,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se
|
||||
}
|
||||
|
||||
// GetServiceTargetByTargetID mocks base method.
|
||||
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) {
|
||||
func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID)
|
||||
ret0, _ := ret[0].(*reverseproxy.Target)
|
||||
ret0, _ := ret[0].(*service.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1873,10 +1977,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a
|
||||
}
|
||||
|
||||
// GetServices mocks base method.
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) {
|
||||
func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret0, _ := ret[0].([]*service.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@@ -1946,6 +2050,21 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTakenIPs", reflect.TypeOf((*MockStore)(nil).GetTakenIPs), ctx, lockStrength, accountId)
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID mocks base method.
|
||||
func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID)
|
||||
ret0, _ := ret[0].([]*service.Target)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetTargetsByServiceID indicates an expected call of GetTargetsByServiceID.
|
||||
func (mr *MockStoreMockRecorder) GetTargetsByServiceID(ctx, lockStrength, accountID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTargetsByServiceID", reflect.TypeOf((*MockStore)(nil).GetTargetsByServiceID), ctx, lockStrength, accountID, serviceID)
|
||||
}
|
||||
|
||||
// GetTokenIDByHashedToken mocks base method.
|
||||
func (m *MockStore) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2327,6 +2446,20 @@ func (mr *MockStoreMockRecorder) RemoveResourceFromGroup(ctx, accountId, groupID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveResourceFromGroup", reflect.TypeOf((*MockStore)(nil).RemoveResourceFromGroup), ctx, accountId, groupID, resourceID)
|
||||
}
|
||||
|
||||
// RenewEphemeralService mocks base method.
|
||||
func (m *MockStore) RenewEphemeralService(ctx context.Context, accountID, peerID, domain string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RenewEphemeralService", ctx, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RenewEphemeralService indicates an expected call of RenewEphemeralService.
|
||||
func (mr *MockStoreMockRecorder) RenewEphemeralService(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewEphemeralService", reflect.TypeOf((*MockStore)(nil).RenewEphemeralService), ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// RevokeProxyAccessToken mocks base method.
|
||||
func (m *MockStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2551,6 +2684,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck)
|
||||
}
|
||||
|
||||
// SaveProxy mocks base method.
|
||||
func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SaveProxy indicates an expected call of SaveProxy.
|
||||
func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy)
|
||||
}
|
||||
|
||||
// SaveProxyAccessToken mocks base method.
|
||||
func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2746,8 +2893,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups)
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat mocks base method.
|
||||
func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat.
|
||||
func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error {
|
||||
func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateService", ctx, service)
|
||||
ret0, _ := ret[0].(error)
|
||||
|
||||
185
management/server/telemetry/account_aggregator.go
Normal file
185
management/server/telemetry/account_aggregator.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||
)
|
||||
|
||||
// AccountDurationAggregator uses OpenTelemetry histograms per account to calculate P95
|
||||
// without publishing individual account labels
|
||||
type AccountDurationAggregator struct {
|
||||
mu sync.RWMutex
|
||||
accounts map[string]*accountHistogram
|
||||
meterProvider *sdkmetric.MeterProvider
|
||||
manualReader *sdkmetric.ManualReader
|
||||
|
||||
FlushInterval time.Duration
|
||||
MaxAge time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
type accountHistogram struct {
|
||||
histogram metric.Int64Histogram
|
||||
lastUpdate time.Time
|
||||
}
|
||||
|
||||
// NewAccountDurationAggregator creates aggregator using OTel histograms
|
||||
func NewAccountDurationAggregator(ctx context.Context, flushInterval, maxAge time.Duration) *AccountDurationAggregator {
|
||||
manualReader := sdkmetric.NewManualReader(
|
||||
sdkmetric.WithTemporalitySelector(func(kind sdkmetric.InstrumentKind) metricdata.Temporality {
|
||||
return metricdata.DeltaTemporality
|
||||
}),
|
||||
)
|
||||
|
||||
meterProvider := sdkmetric.NewMeterProvider(
|
||||
sdkmetric.WithReader(manualReader),
|
||||
)
|
||||
|
||||
return &AccountDurationAggregator{
|
||||
accounts: make(map[string]*accountHistogram),
|
||||
meterProvider: meterProvider,
|
||||
manualReader: manualReader,
|
||||
FlushInterval: flushInterval,
|
||||
MaxAge: maxAge,
|
||||
ctx: ctx,
|
||||
}
|
||||
}
|
||||
|
||||
// Record adds a duration for an account using OTel histogram
|
||||
func (a *AccountDurationAggregator) Record(accountID string, duration time.Duration) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
accHist, exists := a.accounts[accountID]
|
||||
if !exists {
|
||||
meter := a.meterProvider.Meter("account-aggregator")
|
||||
histogram, err := meter.Int64Histogram(
|
||||
"sync_duration_per_account",
|
||||
metric.WithUnit("milliseconds"),
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accHist = &accountHistogram{
|
||||
histogram: histogram,
|
||||
}
|
||||
a.accounts[accountID] = accHist
|
||||
}
|
||||
|
||||
accHist.histogram.Record(a.ctx, duration.Milliseconds(),
|
||||
metric.WithAttributes(attribute.String("account_id", accountID)))
|
||||
accHist.lastUpdate = time.Now()
|
||||
}
|
||||
|
||||
// FlushAndGetP95s extracts P95 from each account's histogram
|
||||
func (a *AccountDurationAggregator) FlushAndGetP95s() []int64 {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
var rm metricdata.ResourceMetrics
|
||||
err := a.manualReader.Collect(a.ctx, &rm)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
p95s := make([]int64, 0, len(a.accounts))
|
||||
|
||||
for _, scopeMetrics := range rm.ScopeMetrics {
|
||||
for _, metric := range scopeMetrics.Metrics {
|
||||
histogramData, ok := metric.Data.(metricdata.Histogram[int64])
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, dataPoint := range histogramData.DataPoints {
|
||||
a.processDataPoint(dataPoint, now, &p95s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a.cleanupStaleAccounts(now)
|
||||
|
||||
return p95s
|
||||
}
|
||||
|
||||
// processDataPoint extracts P95 from a single histogram data point
|
||||
func (a *AccountDurationAggregator) processDataPoint(dataPoint metricdata.HistogramDataPoint[int64], now time.Time, p95s *[]int64) {
|
||||
accountID := extractAccountID(dataPoint)
|
||||
if accountID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if p95 := calculateP95FromHistogram(dataPoint); p95 > 0 {
|
||||
*p95s = append(*p95s, p95)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleAccounts removes accounts that haven't been updated recently
|
||||
func (a *AccountDurationAggregator) cleanupStaleAccounts(now time.Time) {
|
||||
for accountID := range a.accounts {
|
||||
if a.isStaleAccount(accountID, now) {
|
||||
delete(a.accounts, accountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractAccountID retrieves the account_id from histogram data point attributes
|
||||
func extractAccountID(dp metricdata.HistogramDataPoint[int64]) string {
|
||||
for _, attr := range dp.Attributes.ToSlice() {
|
||||
if attr.Key == "account_id" {
|
||||
return attr.Value.AsString()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isStaleAccount checks if an account hasn't been updated recently
|
||||
func (a *AccountDurationAggregator) isStaleAccount(accountID string, now time.Time) bool {
|
||||
accHist, exists := a.accounts[accountID]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
return now.Sub(accHist.lastUpdate) > a.MaxAge
|
||||
}
|
||||
|
||||
// calculateP95FromHistogram computes P95 from OTel histogram data
|
||||
func calculateP95FromHistogram(dp metricdata.HistogramDataPoint[int64]) int64 {
|
||||
if dp.Count == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
targetCount := uint64(math.Ceil(float64(dp.Count) * 0.95))
|
||||
if targetCount == 0 {
|
||||
targetCount = 1
|
||||
}
|
||||
var cumulativeCount uint64
|
||||
|
||||
for i, bucketCount := range dp.BucketCounts {
|
||||
cumulativeCount += bucketCount
|
||||
if cumulativeCount >= targetCount {
|
||||
if i < len(dp.Bounds) {
|
||||
return int64(dp.Bounds[i])
|
||||
}
|
||||
if maxVal, defined := dp.Max.Value(); defined {
|
||||
return maxVal
|
||||
}
|
||||
return dp.Sum / int64(dp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
return dp.Sum / int64(dp.Count)
|
||||
}
|
||||
|
||||
// Shutdown cleans up resources
|
||||
func (a *AccountDurationAggregator) Shutdown() error {
|
||||
return a.meterProvider.Shutdown(a.ctx)
|
||||
}
|
||||
219
management/server/telemetry/account_aggregator_test.go
Normal file
219
management/server/telemetry/account_aggregator_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeltaTemporality_P95ReflectsCurrentWindow(t *testing.T) {
|
||||
// Verify that with delta temporality, each flush window only reflects
|
||||
// recordings since the last flush — not all-time data.
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// Window 1: Record 100 slow requests (500ms each)
|
||||
for range 100 {
|
||||
agg.Record("account-A", 500*time.Millisecond)
|
||||
}
|
||||
|
||||
p95sWindow1 := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95sWindow1, 1, "should have P95 for one account")
|
||||
firstP95 := p95sWindow1[0]
|
||||
assert.GreaterOrEqual(t, firstP95, int64(200),
|
||||
"first window P95 should reflect the 500ms recordings")
|
||||
|
||||
// Window 2: Record 100 FAST requests (10ms each)
|
||||
for range 100 {
|
||||
agg.Record("account-A", 10*time.Millisecond)
|
||||
}
|
||||
|
||||
p95sWindow2 := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95sWindow2, 1, "should have P95 for one account")
|
||||
secondP95 := p95sWindow2[0]
|
||||
|
||||
// With delta temporality the P95 should drop significantly because
|
||||
// the first window's slow recordings are no longer included.
|
||||
assert.Less(t, secondP95, firstP95,
|
||||
"second window P95 should be lower than first — delta temporality "+
|
||||
"ensures each window only reflects recent recordings")
|
||||
}
|
||||
|
||||
func TestEqualWeightPerAccount(t *testing.T) {
|
||||
// Verify that each account contributes exactly one P95 value,
|
||||
// regardless of how many requests it made.
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// Account A: 10,000 requests at 500ms (noisy customer)
|
||||
for range 10000 {
|
||||
agg.Record("account-A", 500*time.Millisecond)
|
||||
}
|
||||
|
||||
// Accounts B, C, D: 10 requests each at 50ms (normal customers)
|
||||
for _, id := range []string{"account-B", "account-C", "account-D"} {
|
||||
for range 10 {
|
||||
agg.Record(id, 50*time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
|
||||
// Should get exactly 4 P95 values — one per account
|
||||
assert.Len(t, p95s, 4, "each account should contribute exactly one P95")
|
||||
}
|
||||
|
||||
func TestStaleAccountEviction(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
// Use a very short MaxAge so we can test staleness
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
agg.Record("account-A", 100*time.Millisecond)
|
||||
agg.Record("account-B", 200*time.Millisecond)
|
||||
|
||||
// Both accounts should appear
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
assert.Len(t, p95s, 2, "both accounts should have P95 values")
|
||||
|
||||
// Wait for account-A to become stale, then only update account-B
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
agg.Record("account-B", 200*time.Millisecond)
|
||||
|
||||
p95s = agg.FlushAndGetP95s()
|
||||
assert.Len(t, p95s, 1, "both accounts should have P95 values")
|
||||
|
||||
// account-A should have been evicted from the accounts map
|
||||
agg.mu.RLock()
|
||||
_, accountAExists := agg.accounts["account-A"]
|
||||
_, accountBExists := agg.accounts["account-B"]
|
||||
agg.mu.RUnlock()
|
||||
|
||||
assert.False(t, accountAExists, "stale account-A should be evicted from map")
|
||||
assert.True(t, accountBExists, "active account-B should remain in map")
|
||||
}
|
||||
|
||||
func TestStaleAccountEviction_DoesNotReappear(t *testing.T) {
|
||||
// Verify that with delta temporality, an evicted stale account does not
|
||||
// reappear in subsequent flushes.
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 50*time.Millisecond)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
agg.Record("account-stale", 100*time.Millisecond)
|
||||
|
||||
// Wait for it to become stale
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
|
||||
// First flush: should detect staleness and evict
|
||||
_ = agg.FlushAndGetP95s()
|
||||
|
||||
agg.mu.RLock()
|
||||
_, exists := agg.accounts["account-stale"]
|
||||
agg.mu.RUnlock()
|
||||
assert.False(t, exists, "account should be evicted after first flush")
|
||||
|
||||
// Second flush: with delta temporality, the stale account should NOT reappear
|
||||
p95sSecond := agg.FlushAndGetP95s()
|
||||
assert.Empty(t, p95sSecond,
|
||||
"evicted account should not reappear in subsequent flushes with delta temporality")
|
||||
}
|
||||
|
||||
func TestP95Calculation_SingleSample(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
agg.Record("account-A", 150*time.Millisecond)
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95s, 1)
|
||||
// With a single sample, P95 should be the bucket bound containing 150ms
|
||||
assert.Greater(t, p95s[0], int64(0), "P95 of a single sample should be positive")
|
||||
}
|
||||
|
||||
func TestP95Calculation_AllSameValue(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// All samples are 100ms — P95 should be the bucket bound containing 100ms
|
||||
for range 100 {
|
||||
agg.Record("account-A", 100*time.Millisecond)
|
||||
}
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95s, 1)
|
||||
assert.Greater(t, p95s[0], int64(0))
|
||||
}
|
||||
|
||||
func TestMultipleAccounts_IndependentP95s(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
agg := NewAccountDurationAggregator(ctx, time.Minute, 5*time.Minute)
|
||||
defer func(agg *AccountDurationAggregator) {
|
||||
err := agg.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to shutdown aggregator: %v", err)
|
||||
}
|
||||
}(agg)
|
||||
|
||||
// Account A: all fast (10ms)
|
||||
for range 100 {
|
||||
agg.Record("account-fast", 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// Account B: all slow (5000ms)
|
||||
for range 100 {
|
||||
agg.Record("account-slow", 5000*time.Millisecond)
|
||||
}
|
||||
|
||||
p95s := agg.FlushAndGetP95s()
|
||||
require.Len(t, p95s, 2, "should have two P95 values")
|
||||
|
||||
// Find min and max — they should differ significantly
|
||||
minP95 := p95s[0]
|
||||
maxP95 := p95s[1]
|
||||
if minP95 > maxP95 {
|
||||
minP95, maxP95 = maxP95, minP95
|
||||
}
|
||||
|
||||
assert.Less(t, minP95, int64(1000),
|
||||
"fast account P95 should be well under 1000ms")
|
||||
assert.Greater(t, maxP95, int64(1000),
|
||||
"slow account P95 should be well over 1000ms")
|
||||
}
|
||||
@@ -13,18 +13,24 @@ const HighLatencyThreshold = time.Second * 7
|
||||
|
||||
// GRPCMetrics are gRPC server metrics
|
||||
type GRPCMetrics struct {
|
||||
meter metric.Meter
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
syncRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
loginRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestHighLatencyCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
meter metric.Meter
|
||||
syncRequestsCounter metric.Int64Counter
|
||||
syncRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestsCounter metric.Int64Counter
|
||||
loginRequestsBlockedCounter metric.Int64Counter
|
||||
loginRequestHighLatencyCounter metric.Int64Counter
|
||||
getKeyRequestsCounter metric.Int64Counter
|
||||
activeStreamsGauge metric.Int64ObservableGauge
|
||||
syncRequestDuration metric.Int64Histogram
|
||||
syncRequestDurationP95ByAccount metric.Int64Histogram
|
||||
loginRequestDuration metric.Int64Histogram
|
||||
loginRequestDurationP95ByAccount metric.Int64Histogram
|
||||
channelQueueLength metric.Int64Histogram
|
||||
ctx context.Context
|
||||
|
||||
// Per-account aggregation
|
||||
syncDurationAggregator *AccountDurationAggregator
|
||||
loginDurationAggregator *AccountDurationAggregator
|
||||
}
|
||||
|
||||
// NewGRPCMetrics creates new GRPCMetrics struct and registers common metrics of the gRPC server
|
||||
@@ -93,6 +99,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
syncRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.sync.request.duration.p95.by.account.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("P95 duration of sync requests aggregated per account - each data point represents one account's P95"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestDuration, err := meter.Int64Histogram("management.grpc.login.request.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("Duration of the login gRPC requests from the peers to authenticate and receive initial configuration and relay credentials"),
|
||||
@@ -101,6 +115,14 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginRequestDurationP95ByAccount, err := meter.Int64Histogram("management.grpc.login.request.duration.p95.by.account.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithDescription("P95 duration of login requests aggregated per account - each data point represents one account's P95"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// We use histogram here as we have multiple channel at the same time and we want to see a slice at any given time
|
||||
// Then we should be able to extract min, manx, mean and the percentiles.
|
||||
// TODO(yury): This needs custom bucketing as we are interested in the values from 0 to server.channelBufferSize (100)
|
||||
@@ -113,20 +135,32 @@ func NewGRPCMetrics(ctx context.Context, meter metric.Meter) (*GRPCMetrics, erro
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GRPCMetrics{
|
||||
meter: meter,
|
||||
syncRequestsCounter: syncRequestsCounter,
|
||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||
loginRequestsCounter: loginRequestsCounter,
|
||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||
activeStreamsGauge: activeStreamsGauge,
|
||||
syncRequestDuration: syncRequestDuration,
|
||||
loginRequestDuration: loginRequestDuration,
|
||||
channelQueueLength: channelQueue,
|
||||
ctx: ctx,
|
||||
}, err
|
||||
syncDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
|
||||
loginDurationAggregator := NewAccountDurationAggregator(ctx, 60*time.Second, 5*time.Minute)
|
||||
|
||||
grpcMetrics := &GRPCMetrics{
|
||||
meter: meter,
|
||||
syncRequestsCounter: syncRequestsCounter,
|
||||
syncRequestsBlockedCounter: syncRequestsBlockedCounter,
|
||||
loginRequestsCounter: loginRequestsCounter,
|
||||
loginRequestsBlockedCounter: loginRequestsBlockedCounter,
|
||||
loginRequestHighLatencyCounter: loginRequestHighLatencyCounter,
|
||||
getKeyRequestsCounter: getKeyRequestsCounter,
|
||||
activeStreamsGauge: activeStreamsGauge,
|
||||
syncRequestDuration: syncRequestDuration,
|
||||
syncRequestDurationP95ByAccount: syncRequestDurationP95ByAccount,
|
||||
loginRequestDuration: loginRequestDuration,
|
||||
loginRequestDurationP95ByAccount: loginRequestDurationP95ByAccount,
|
||||
channelQueueLength: channelQueue,
|
||||
ctx: ctx,
|
||||
syncDurationAggregator: syncDurationAggregator,
|
||||
loginDurationAggregator: loginDurationAggregator,
|
||||
}
|
||||
|
||||
go grpcMetrics.startSyncP95Flusher()
|
||||
go grpcMetrics.startLoginP95Flusher()
|
||||
|
||||
return grpcMetrics, err
|
||||
}
|
||||
|
||||
// CountSyncRequest counts the number of gRPC sync requests coming to the gRPC API
|
||||
@@ -157,6 +191,9 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestBlocked() {
|
||||
// CountLoginRequestDuration counts the duration of the login gRPC requests
|
||||
func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration, accountID string) {
|
||||
grpcMetrics.loginRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||
|
||||
grpcMetrics.loginDurationAggregator.Record(accountID, duration)
|
||||
|
||||
if duration > HighLatencyThreshold {
|
||||
grpcMetrics.loginRequestHighLatencyCounter.Add(grpcMetrics.ctx, 1, metric.WithAttributes(attribute.String(AccountIDLabel, accountID)))
|
||||
}
|
||||
@@ -165,6 +202,44 @@ func (grpcMetrics *GRPCMetrics) CountLoginRequestDuration(duration time.Duration
|
||||
// CountSyncRequestDuration counts the duration of the sync gRPC requests
|
||||
func (grpcMetrics *GRPCMetrics) CountSyncRequestDuration(duration time.Duration, accountID string) {
|
||||
grpcMetrics.syncRequestDuration.Record(grpcMetrics.ctx, duration.Milliseconds())
|
||||
|
||||
grpcMetrics.syncDurationAggregator.Record(accountID, duration)
|
||||
}
|
||||
|
||||
// startSyncP95Flusher periodically flushes per-account sync P95 values to the histogram
|
||||
func (grpcMetrics *GRPCMetrics) startSyncP95Flusher() {
|
||||
ticker := time.NewTicker(grpcMetrics.syncDurationAggregator.FlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-grpcMetrics.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p95s := grpcMetrics.syncDurationAggregator.FlushAndGetP95s()
|
||||
for _, p95 := range p95s {
|
||||
grpcMetrics.syncRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startLoginP95Flusher periodically flushes per-account login P95 values to the histogram
|
||||
func (grpcMetrics *GRPCMetrics) startLoginP95Flusher() {
|
||||
ticker := time.NewTicker(grpcMetrics.loginDurationAggregator.FlushInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-grpcMetrics.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p95s := grpcMetrics.loginDurationAggregator.FlushAndGetP95s()
|
||||
for _, p95 := range p95s {
|
||||
grpcMetrics.loginRequestDurationP95ByAccount.Record(grpcMetrics.ctx, p95)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterConnectedStreams registers a function that collects number of active streams and feeds it to the metrics gauge.
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -100,7 +100,7 @@ type Account struct {
|
||||
NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"`
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
// Settings is a dictionary of Account settings
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||
@@ -906,7 +906,7 @@ func (a *Account) Copy() *Account {
|
||||
networkResources = append(networkResources, resource.Copy())
|
||||
}
|
||||
|
||||
services := []*reverseproxy.Service{}
|
||||
services := []*service.Service{}
|
||||
for _, service := range a.Services {
|
||||
services = append(services, service.Copy())
|
||||
}
|
||||
@@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) {
|
||||
for _, target := range service.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
@@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) {
|
||||
func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) {
|
||||
port, ok := a.resolveTargetPort(ctx, target)
|
||||
if !ok {
|
||||
return
|
||||
@@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) {
|
||||
func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) {
|
||||
if target.Port != 0 {
|
||||
return target.Port, true
|
||||
}
|
||||
@@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy {
|
||||
policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path)
|
||||
return &Policy{
|
||||
ID: policyID,
|
||||
|
||||
@@ -47,6 +47,11 @@ type Settings struct {
|
||||
// NetworkRange is the custom network range for that account
|
||||
NetworkRange netip.Prefix `gorm:"serializer:json"`
|
||||
|
||||
// PeerExposeEnabled enables or disables peer-initiated service expose
|
||||
PeerExposeEnabled bool
|
||||
// PeerExposeGroups list of peer group IDs allowed to expose services
|
||||
PeerExposeGroups []string `gorm:"serializer:json"`
|
||||
|
||||
// Extra is a dictionary of Account settings
|
||||
Extra *ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"`
|
||||
|
||||
@@ -80,6 +85,8 @@ func (s *Settings) Copy() *Settings {
|
||||
PeerInactivityExpiration: s.PeerInactivityExpiration,
|
||||
|
||||
RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled,
|
||||
PeerExposeEnabled: s.PeerExposeEnabled,
|
||||
PeerExposeGroups: slices.Clone(s.PeerExposeGroups),
|
||||
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
||||
DNSDomain: s.DNSDomain,
|
||||
NetworkRange: s.NetworkRange,
|
||||
|
||||
@@ -742,6 +742,11 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
|
||||
if err != nil {
|
||||
return false, nil, nil, nil, fmt.Errorf("failed to re-read initiator user in transaction: %w", err)
|
||||
}
|
||||
|
||||
// Ensure the initiator still has admin privileges
|
||||
if initiatorUser.HasAdminPower() && !freshInitiator.HasAdminPower() {
|
||||
return false, nil, nil, nil, status.Errorf(status.PermissionDenied, "initiator role was changed during request processing")
|
||||
}
|
||||
initiatorUser = freshInitiator
|
||||
}
|
||||
|
||||
@@ -872,10 +877,6 @@ func validateUserUpdate(groupsMap map[string]*types.Group, initiatorUser, oldUse
|
||||
return nil
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only admins and owners can update users")
|
||||
}
|
||||
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
|
||||
@@ -2032,27 +2032,6 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidateUserUpdate_RejectsNonAdminInitiator(t *testing.T) {
|
||||
groupsMap := map[string]*types.Group{}
|
||||
|
||||
initiator := &types.User{
|
||||
Id: "initiator",
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
oldUser := &types.User{
|
||||
Id: "target",
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
update := &types.User{
|
||||
Id: "target",
|
||||
Role: types.UserRoleOwner,
|
||||
}
|
||||
|
||||
err := validateUserUpdate(groupsMap, initiator, oldUser, update)
|
||||
require.Error(t, err, "regular user should not be able to promote to owner")
|
||||
assert.Contains(t, err.Error(), "only admins and owners can update users")
|
||||
}
|
||||
|
||||
func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
|
||||
s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
@@ -2109,7 +2088,7 @@ func TestProcessUserUpdate_RejectsStaleInitiatorRole(t *testing.T) {
|
||||
})
|
||||
|
||||
require.Error(t, err, "processUserUpdate should reject stale initiator whose role was demoted")
|
||||
assert.Contains(t, err.Error(), "only admins and owners can update users")
|
||||
assert.Contains(t, err.Error(), "initiator role was changed during request processing")
|
||||
|
||||
targetUser, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetID)
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user