Merge branch 'main' into proto-ipv6-overlay

This commit is contained in:
Viktor Liu
2026-04-17 05:41:53 +02:00
99 changed files with 4694 additions and 930 deletions

View File

@@ -181,7 +181,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []
return modified, newUserAutoGroups, newGroupsToCreate, nil
}
// BuildManager creates a new DefaultAccountManager with a provided Store
// BuildManager creates a new DefaultAccountManager with all dependencies.
func BuildManager(
ctx context.Context,
config *nbconfig.Config,
@@ -199,6 +199,7 @@ func BuildManager(
settingsManager settings.Manager,
permissionsManager permissions.Manager,
disableDefaultPolicy bool,
sharedCacheStore cacheStore.StoreInterface,
) (*DefaultAccountManager, error) {
start := time.Now()
defer func() {
@@ -247,16 +248,12 @@ func BuildManager(
log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter)
}
cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn)
if err != nil {
return nil, fmt.Errorf("getting cache store: %s", err)
}
am.externalCacheManager = nbcache.NewUserDataCache(cacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore)
am.externalCacheManager = nbcache.NewUserDataCache(sharedCacheStore)
am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, sharedCacheStore)
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
go func() {
err := am.warmupIDPCache(ctx, cacheStore)
err := am.warmupIDPCache(ctx, sharedCacheStore)
if err != nil {
log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err)
// todo retry?

View File

@@ -63,20 +63,11 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
startTime := time.Now()
ac.getAccountRequestCh <- req
select {
case <-ctx.Done():
return nil, ctx.Err()
case ac.getAccountRequestCh <- req:
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-req.ResultChan:
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
result := <-req.ResultChan
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {

View File

@@ -3173,10 +3173,15 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
ctx := context.Background()
cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, nil, err
}
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, nil, err
}

View File

@@ -17,12 +17,24 @@ import (
// RedisStoreEnvVar is the environment variable that determines if a redis store should be used.
// The value should follow redis URL format. https://github.com/redis/redis-specifications/blob/master/uri/redis.txt
const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS"
const RedisStoreEnvVar = "NB_CACHE_REDIS_ADDRESS"
// legacyIdPCacheRedisEnvVar is the previous environment variable used for IDP cache.
const legacyIdPCacheRedisEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS"
const (
// DefaultStoreMaxTimeout is the default max timeout for the shared cache store.
DefaultStoreMaxTimeout = 7 * 24 * time.Hour
// DefaultStoreCleanupInterval is the default cleanup interval for the shared cache store.
DefaultStoreCleanupInterval = 30 * time.Minute
// DefaultStoreMaxConn is the default max connections for the shared cache store.
DefaultStoreMaxConn = 1000
)
// NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar
// to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store.
func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) {
redisAddr := os.Getenv(RedisStoreEnvVar)
redisAddr := GetAddrFromEnv()
if redisAddr != "" {
return getRedisStore(ctx, redisAddr, maxConn)
}
@@ -30,6 +42,15 @@ func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, ma
return gocache_store.NewGoCache(goc), nil
}
// GetAddrFromEnv returns the redis address from the environment variable RedisStoreEnvVar or its legacy counterpart.
func GetAddrFromEnv() string {
addr := os.Getenv(RedisStoreEnvVar)
if addr == "" {
addr = os.Getenv(legacyIdPCacheRedisEnvVar)
}
return addr
}
func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) {
options, err := redis.ParseURL(redisEnvAddr)
if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -225,11 +226,17 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, err
}
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
}
func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -105,6 +105,12 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
router.NetworkID = networkID
router.AccountID = accountID
router.Enabled = true
if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -157,6 +163,11 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
router.ID = mux.Vars(r)["routerId"]
router.AccountID = accountID
if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)

View File

@@ -22,6 +22,7 @@ import (
nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
@@ -191,11 +192,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
usersManager := users.NewManager(testStore)

View File

@@ -1170,13 +1170,17 @@ func Test_NetworkRouters_Create(t *testing.T) {
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
t.Helper()
assert.NotEmpty(t, router.Id)
assert.Equal(t, peerID, *router.Peer)
assert.Equal(t, 1, len(*router.PeerGroups))
expectedStatus: http.StatusBadRequest,
},
{
name: "Create router without peer and peer_groups",
networkId: "testNetworkId",
requestBody: &api.NetworkRouterRequest{
Masquerade: true,
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Create router in non-existing network",
@@ -1341,13 +1345,18 @@ func Test_NetworkRouters_Update(t *testing.T) {
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
t.Helper()
assert.Equal(t, "testRouterId", router.Id)
assert.Equal(t, peerID, *router.Peer)
assert.Equal(t, 1, len(*router.PeerGroups))
expectedStatus: http.StatusBadRequest,
},
{
name: "Update router without peer and peer_groups",
networkId: "testNetworkId",
routerId: "testRouterId",
requestBody: &api.NetworkRouterRequest{
Masquerade: true,
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusBadRequest,
},
}

View File

@@ -35,6 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
@@ -87,22 +88,22 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
jobManager := job.NewJobManager(nil, store, peersManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("Failed to create cache store: %v", err)
}
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
@@ -216,22 +217,22 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
jobManager := job.NewJobManager(nil, store, peersManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("Failed to create cache store: %v", err)
}
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {

View File

@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"path/filepath"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
@@ -19,6 +20,7 @@ import (
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
@@ -83,10 +85,15 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update
permissionsManager := permissions.NewManager(testStore)
peersManager := peers.NewManager(testStore, permissionsManager)
cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, nil, err
}
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, testStore)
networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{})
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, nil, err
}

View File

@@ -66,14 +66,14 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
}
// Create a new Admin SDK Directory service client
adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey)
credentialsOption, err := getGoogleCredentialsOption(ctx, config.ServiceAccountKey)
if err != nil {
return nil, err
}
service, err := admin.NewService(context.Background(),
option.WithScopes(admin.AdminDirectoryUserReadonlyScope),
option.WithCredentials(adminCredentials),
credentialsOption,
)
if err != nil {
return nil, err
@@ -218,39 +218,32 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e
return nil
}
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
// If that fails, it falls back to using the default Google credentials path.
// It returns the retrieved credentials or an error if unsuccessful.
func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) {
// getGoogleCredentialsOption returns the google.golang.org/api option carrying
// Google credentials derived from the provided serviceAccountKey.
// It decodes the base64-encoded serviceAccountKey and uses it as the credentials JSON.
// If the key is empty, it falls back to the default Google credentials path.
func getGoogleCredentialsOption(ctx context.Context, serviceAccountKey string) (option.ClientOption, error) {
log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key")
decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey)
if err != nil {
return nil, fmt.Errorf("failed to decode service account key: %w", err)
}
creds, err := google.CredentialsFromJSON(
context.Background(),
decodeKey,
admin.AdminDirectoryUserReadonlyScope,
)
if err == nil {
// No need to fallback to the default Google credentials path
return creds, nil
if len(decodeKey) > 0 {
return option.WithAuthCredentialsJSON(option.ServiceAccount, decodeKey), nil
}
log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err)
log.WithContext(ctx).Debug("falling back to default google credentials location")
log.WithContext(ctx).Debug("no service account key provided, falling back to default google credentials location")
creds, err = google.FindDefaultCredentials(
context.Background(),
creds, err := google.FindDefaultCredentials(
ctx,
admin.AdminDirectoryUserReadonlyScope,
)
if err != nil {
return nil, err
}
return creds, nil
return option.WithCredentials(creds), nil
}
// parseGoogleWorkspaceUser parse google user to UserData.

View File

@@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
@@ -369,9 +370,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
requestBuffer := NewAccountRequestBuffer(ctx, store)
ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager))
cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
cleanup()
return nil, nil, "", cleanup, err
}
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config)
accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
cleanup()

View File

@@ -28,6 +28,7 @@ import (
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
@@ -207,6 +208,12 @@ func startServer(
jobManager := job.NewJobManager(nil, str, peersManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("failed creating cache store: %v", err)
}
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(ctx, str)
networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)), config)
@@ -227,7 +234,8 @@ func startServer(
port_forwarding.NewControllerMock(),
settingsMockManager,
permissionsManager,
false)
false,
cacheStore)
if err != nil {
t.Fatalf("failed creating an account manager: %v", err)
}

View File

@@ -17,6 +17,7 @@ import (
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -794,11 +795,17 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, err
}
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
}
func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -21,11 +21,7 @@ type NetworkRouter struct {
}
func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) {
if peer != "" && len(peerGroups) > 0 {
return nil, errors.New("peer and peerGroups cannot be set at the same time")
}
return &NetworkRouter{
r := &NetworkRouter{
ID: xid.New().String(),
AccountID: accountID,
NetworkID: networkID,
@@ -34,7 +30,25 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup
Masquerade: masquerade,
Metric: metric,
Enabled: enabled,
}, nil
}
if err := r.Validate(); err != nil {
return nil, err
}
return r, nil
}
func (n *NetworkRouter) Validate() error {
if n.Peer != "" && len(n.PeerGroups) > 0 {
return errors.New("peer and peer_groups cannot be set at the same time")
}
if n.Peer == "" && len(n.PeerGroups) == 0 {
return errors.New("either peer or peer_groups must be provided")
}
return nil
}
func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter {

View File

@@ -38,7 +38,7 @@ func TestNewNetworkRouter(t *testing.T) {
expectedError: false,
},
{
name: "Valid with no peer or peerGroups",
name: "Invalid with no peer or peerGroups",
networkID: "network-3",
accountID: "account-3",
peer: "",
@@ -46,7 +46,18 @@ func TestNewNetworkRouter(t *testing.T) {
masquerade: true,
metric: 300,
enabled: true,
expectedError: false,
expectedError: true,
},
{
name: "Invalid with empty peerGroups slice",
networkID: "network-5",
accountID: "account-5",
peer: "",
peerGroups: []string{},
masquerade: true,
metric: 500,
enabled: true,
expectedError: true,
},
// Invalid cases

View File

@@ -32,6 +32,7 @@ import (
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
@@ -1318,11 +1319,15 @@ func Test_RegisterPeerByUser(t *testing.T) {
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1405,11 +1410,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1560,11 +1569,15 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1642,11 +1655,15 @@ func Test_LoginPeer(t *testing.T) {
peersManager := peers.NewManager(s, permissionsManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
require.NoError(t, err)
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, s)
networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

View File

@@ -19,6 +19,7 @@ import (
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -1292,11 +1293,17 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel.
peersManager := peers.NewManager(store, permissionsManager)
ctx := context.Background()
cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, nil, err
}
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{})
am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, nil, err
}

View File

@@ -1018,10 +1018,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
// GetCustomDomainsCounts returns the total and validated custom domain counts.
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
var total, validated int64
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
return 0, 0, err
}
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
return 0, 0, err
}
return total, validated, nil
@@ -4457,7 +4457,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4476,7 +4476,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr
// GetAllProxyAccessTokens retrieves all proxy access tokens.
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4492,7 +4492,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc
// SaveProxyAccessToken saves a proxy access token to the database.
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
if result := s.db.Create(token); result.Error != nil {
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
}
return nil
@@ -4500,7 +4500,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA
// RevokeProxyAccessToken revokes a proxy access token by its ID.
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
if result.Error != nil {
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
}
@@ -4514,7 +4514,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
result := s.db.Model(&types.ProxyAccessToken{}).
Where(idQueryCondition, tokenID).
Update("last_used", time.Now().UTC())
if result.Error != nil {
@@ -5204,7 +5204,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5220,7 +5220,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength
// GetServicesByCluster returns all services for the given proxy cluster.
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5330,7 +5330,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
var logs []*accesslogs.AccessLogEntry
var totalCount int64
baseQuery := s.db.WithContext(ctx).
baseQuery := s.db.
Model(&accesslogs.AccessLogEntry{}).
Where(accountIDCondition, accountID)
@@ -5341,7 +5341,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
}
query := s.db.WithContext(ctx).
query := s.db.
Where(accountIDCondition, accountID)
query = s.applyAccessLogFilters(query, filter)
@@ -5378,7 +5378,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
// DeleteOldAccessLogs deletes all access logs older than the specified time
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
result := s.db.WithContext(ctx).
result := s.db.
Where("timestamp < ?", olderThan).
Delete(&accesslogs.AccessLogEntry{})
@@ -5467,7 +5467,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
result := s.db.Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
@@ -5479,7 +5479,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
result := s.db.WithContext(ctx).
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", now)
@@ -5498,7 +5498,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.WithContext(ctx).Save(p).Error; err != nil {
if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
}
@@ -5511,7 +5511,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
result := s.db.
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
@@ -5550,6 +5550,7 @@ const proxyActiveThreshold = 2 * time.Minute
var validCapabilityColumns = map[string]struct{}{
"supports_custom_ports": {},
"require_subdomain": {},
"supports_crowdsec": {},
}
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
@@ -5564,6 +5565,59 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
}
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
// have CrowdSec configured. Returns nil when no proxy reported the capability.
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
// requires unanimous support: a single unconfigured proxy would let requests
// bypass reputation checks.
func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec")
}
// getClusterUnanimousCapability returns an aggregated boolean capability
// requiring all active proxies in the cluster to report true.
func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool {
if _, ok := validCapabilityColumns[column]; !ok {
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
return nil
}
var result struct {
Total int64
Reported int64
AllTrue bool
}
// All active proxies must have reported the capability (no NULLs) and all
// must report true. A single unreported or false proxy means the cluster
// does not unanimously support the capability.
err := s.db.WithContext(ctx).
Model(&proxy.Proxy{}).
Select("COUNT(*) AS total, "+
"COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+
"COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true").
Where("cluster_address = ? AND status = ? AND last_seen > ?",
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
Scan(&result).Error
if err != nil {
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
return nil
}
if result.Total == 0 || result.Reported == 0 {
return nil
}
// If any proxy has not reported (NULL), we can't confirm unanimous support.
if result.Reported < result.Total {
v := false
return &v
}
return &result.AllTrue
}
// getClusterCapability returns an aggregated boolean capability for the given
// cluster. It checks active (connected, recently seen) proxies and returns:
// - *true if any proxy in the cluster has the capability set to true,
@@ -5580,7 +5634,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
AnyTrue bool
}
err := s.db.WithContext(ctx).
err := s.db.
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
@@ -5604,7 +5658,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
result := s.db.
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})

View File

@@ -290,6 +290,7 @@ type Store interface {
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)

View File

@@ -166,6 +166,19 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// GetClusterSupportsCrowdSec mocks base method.
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()

View File

@@ -183,7 +183,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
w := WrapResponseWriter(rw)
handlerDone := make(chan struct{})
context.AfterFunc(ctx, func() {
select {
case <-handlerDone:
default:
log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)",
r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx))
}
})
h.ServeHTTP(w, r.WithContext(ctx))
close(handlerDone)
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {