merge main

This commit is contained in:
crn4
2026-04-27 18:03:11 +02:00
198 changed files with 10067 additions and 7036 deletions

View File

@@ -5,9 +5,6 @@ import (
"fmt"
"net/http"
"net/netip"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
@@ -67,14 +64,11 @@ import (
)
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
apiPrefix = "/api"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -95,34 +89,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
}
authMiddleware := middleware.NewAuthMiddleware(
@@ -130,7 +100,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
rateLimiter,
appMetrics.GetMeter(),
)

View File

@@ -105,6 +105,12 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
router.NetworkID = networkID
router.AccountID = accountID
router.Enabled = true
if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -157,6 +163,11 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
router.ID = mux.Vars(r)["routerId"]
router.AccountID = accountID
if err := router.Validate(); err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
if err != nil {
util.WriteError(r.Context(), err, w)

View File

@@ -417,7 +417,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMapFromComponents(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -22,6 +22,7 @@ import (
nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
@@ -191,11 +192,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
oidcServer := newFakeOIDCServer()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100)
cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
usersManager := users.NewManager(testStore)

View File

@@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -42,14 +43,9 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
rateLimiter *APIRateLimiter,
meter metric.Meter,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
var patUsageTracker *PATUsageTracker
if meter != nil {
var err error
@@ -87,17 +83,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
request, err := m.checkJWTFromRequest(r, authHeader)
if err != nil {
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, request)
h.ServeHTTP(w, r)
case "token":
request, err := m.checkPATFromRequest(r, authHeader)
if err != nil {
if err := m.checkPATFromRequest(r, authHeader); err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
@@ -106,7 +99,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
util.WriteError(r.Context(), err, w)
return
}
h.ServeHTTP(w, request)
h.ServeHTTP(w, r)
default:
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
@@ -115,19 +108,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts)
// If an error occurs, call the error handler and return an error
if err != nil {
return r, fmt.Errorf("error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
if err != nil {
return r, err
return err
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
@@ -143,7 +136,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return r, err
return err
}
if userAuth.AccountId != accountId {
@@ -153,7 +146,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return r, err
return err
}
err = m.syncUserJWTGroups(ctx, userAuth)
@@ -164,41 +157,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
_, err = m.getUserFromUserAuth(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return r, err
return err
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) {
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return r, fmt.Errorf("error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}
if m.patUsageTracker != nil {
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) {
return status.Errorf(status.TooManyRequests, "too many requests")
}
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {
return r, fmt.Errorf("invalid Token: %w", err)
return fmt.Errorf("invalid Token: %w", err)
}
if time.Now().After(pat.GetExpirationDate()) {
return r, fmt.Errorf("token expired")
return fmt.Errorf("token expired")
}
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
return r, err
return err
}
userAuth := auth.UserAuth{
@@ -216,7 +209,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
}
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
}
func isTerraformRequest(r *http.Request) bool {

View File

@@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)
@@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
NewAPIRateLimiter(rateLimitConfig),
nil,
)
@@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
disabledLimiter := NewAPIRateLimiter(nil)
disabledLimiter.SetEnabled(false)
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
@@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
disabledLimiter,
nil,
)

View File

@@ -4,14 +4,27 @@ import (
"context"
"net"
"net/http"
"os"
"strconv"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
const (
RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED"
RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST"
RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM"
defaultAPIRPM = 6
defaultAPIBurst = 500
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
@@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig {
}
}
func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) {
rpm := defaultAPIRPM
if v := os.Getenv(RateLimitingRPMEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm)
} else {
rpm = value
}
}
if rpm <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM)
rpm = defaultAPIRPM
}
burst := defaultAPIBurst
if v := os.Getenv(RateLimitingBurstEnv); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst)
} else {
burst = value
}
}
if burst <= 0 {
log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst)
burst = defaultAPIBurst
}
return &RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}, os.Getenv(RateLimitingEnabledEnv) == "true"
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
@@ -46,6 +96,7 @@ type APIRateLimiter struct {
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
enabled atomic.Bool
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
@@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
rl.enabled.Store(true)
go rl.cleanupLoop()
return rl
}
func (rl *APIRateLimiter) SetEnabled(enabled bool) {
rl.enabled.Store(enabled)
}
func (rl *APIRateLimiter) Enabled() bool {
return rl.enabled.Load()
}
func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) {
if config == nil {
return
}
if config.RequestsPerMinute <= 0 || config.Burst <= 0 {
log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst)
return
}
newRPS := rate.Limit(config.RequestsPerMinute / 60.0)
newBurst := config.Burst
rl.mu.Lock()
rl.config.RequestsPerMinute = config.RequestsPerMinute
rl.config.Burst = newBurst
snapshot := make([]*rate.Limiter, 0, len(rl.limiters))
for _, entry := range rl.limiters {
snapshot = append(snapshot, entry.limiter)
}
rl.mu.Unlock()
for _, l := range snapshot {
l.SetLimit(newRPS)
l.SetBurst(newBurst)
}
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
if !rl.enabled.Load() {
return true
}
limiter := rl.getLimiter(key)
return limiter.Allow()
}
@@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool {
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
if !rl.enabled.Load() {
return nil
}
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
@@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) {
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !rl.enabled.Load() {
next.ServeHTTP(w, r)
return
}
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)

View File

@@ -1,8 +1,10 @@
package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
@@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) {
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}
func TestAPIRateLimiter_SetEnabled(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("key"))
assert.False(t, rl.Allow("key"), "burst exhausted while enabled")
rl.SetEnabled(false)
assert.False(t, rl.Enabled())
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow("key"), "disabled limiter must always allow")
}
rl.SetEnabled(true)
assert.True(t, rl.Enabled())
assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state")
}
func TestAPIRateLimiter_UpdateConfig(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k1"))
assert.True(t, rl.Allow("k1"))
assert.False(t, rl.Allow("k1"), "burst=2 exhausted")
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
// New burst applies to existing keys in place; bucket refills up to new burst over time,
// but importantly newly-added keys use the updated config immediately.
assert.True(t, rl.Allow("k2"))
for i := 0; i < 9; i++ {
assert.True(t, rl.Allow("k2"))
}
assert.False(t, rl.Allow("k2"), "new burst=10 exhausted")
}
func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
rl.UpdateConfig(nil) // must not panic or zero the config
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
}
func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"))
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute})
rl.Reset("k")
assert.True(t, rl.Allow("k"))
assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored")
}
func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 600,
Burst: 10,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
var wg sync.WaitGroup
stop := make(chan struct{})
for i := 0; i < 8; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("k%d", id)
for {
select {
case <-stop:
return
default:
rl.Allow(key)
}
}
}(i)
}
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 200; i++ {
select {
case <-stop:
return
default:
rl.UpdateConfig(&RateLimiterConfig{
RequestsPerMinute: float64(30 + (i % 90)),
Burst: 1 + (i % 20),
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
rl.SetEnabled(i%2 == 0)
}
}
}()
time.Sleep(100 * time.Millisecond)
close(stop)
wg.Wait()
}
func TestRateLimiterConfigFromEnv(t *testing.T) {
t.Setenv(RateLimitingEnabledEnv, "true")
t.Setenv(RateLimitingRPMEnv, "42")
t.Setenv(RateLimitingBurstEnv, "7")
cfg, enabled := RateLimiterConfigFromEnv()
assert.True(t, enabled)
assert.Equal(t, float64(42), cfg.RequestsPerMinute)
assert.Equal(t, 7, cfg.Burst)
t.Setenv(RateLimitingEnabledEnv, "false")
_, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
t.Setenv(RateLimitingEnabledEnv, "")
t.Setenv(RateLimitingRPMEnv, "")
t.Setenv(RateLimitingBurstEnv, "")
cfg, enabled = RateLimiterConfigFromEnv()
assert.False(t, enabled)
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute)
assert.Equal(t, defaultAPIBurst, cfg.Burst)
t.Setenv(RateLimitingRPMEnv, "0")
t.Setenv(RateLimitingBurstEnv, "-5")
cfg, _ = RateLimiterConfigFromEnv()
assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default")
assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default")
}

View File

@@ -1170,13 +1170,17 @@ func Test_NetworkRouters_Create(t *testing.T) {
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
t.Helper()
assert.NotEmpty(t, router.Id)
assert.Equal(t, peerID, *router.Peer)
assert.Equal(t, 1, len(*router.PeerGroups))
expectedStatus: http.StatusBadRequest,
},
{
name: "Create router without peer and peer_groups",
networkId: "testNetworkId",
requestBody: &api.NetworkRouterRequest{
Masquerade: true,
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Create router in non-existing network",
@@ -1341,13 +1345,18 @@ func Test_NetworkRouters_Update(t *testing.T) {
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusOK,
verifyResponse: func(t *testing.T, router *api.NetworkRouter) {
t.Helper()
assert.Equal(t, "testRouterId", router.Id)
assert.Equal(t, peerID, *router.Peer)
assert.Equal(t, 1, len(*router.PeerGroups))
expectedStatus: http.StatusBadRequest,
},
{
name: "Update router without peer and peer_groups",
networkId: "testNetworkId",
routerId: "testRouterId",
requestBody: &api.NetworkRouterRequest{
Masquerade: true,
Metric: 100,
Enabled: true,
},
expectedStatus: http.StatusBadRequest,
},
}

View File

@@ -35,6 +35,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
@@ -87,22 +88,22 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
jobManager := job.NewJobManager(nil, store, peersManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("Failed to create cache store: %v", err)
}
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
@@ -134,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -216,22 +217,22 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
jobManager := job.NewJobManager(nil, store, peersManager)
ctx := context.Background()
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatalf("Failed to create cache store: %v", err)
}
requestBuffer := server.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{})
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil)
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create proxy token store: %v", err)
}
pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
if err != nil {
t.Fatalf("Failed to create PKCE verifier store: %v", err)
}
proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore)
pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore)
noopMeter := noop.NewMeterProvider().Meter("")
proxyMgr, err := proxymanager.NewManager(store, noopMeter)
if err != nil {
@@ -263,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}