diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7cf0b5765..b7c6c113c 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -105,6 +105,7 @@ func NewAPIHandler( accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, rateLimitingConfig, + appMetrics.GetMeter(), ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 9439165a4..ffd7e0b93 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -9,6 +9,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -31,6 +32,7 @@ type AuthMiddleware struct { getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc rateLimiter *APIRateLimiter + patUsageTracker *PATUsageTracker } // NewAuthMiddleware instance constructor @@ -40,18 +42,29 @@ func NewAuthMiddleware( syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, rateLimiterConfig *RateLimiterConfig, + meter metric.Meter, ) *AuthMiddleware { var rateLimiter *APIRateLimiter if rateLimiterConfig != nil { rateLimiter = NewAPIRateLimiter(rateLimiterConfig) } + var patUsageTracker *PATUsageTracker + if meter != nil { + var err error + patUsageTracker, err = NewPATUsageTracker(context.Background(), meter) + if err != nil { + log.Errorf("Failed to create PAT usage tracker: %s", err) + } + } + return &AuthMiddleware{ authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, rateLimiter: rateLimiter, + patUsageTracker: patUsageTracker, } } @@ -158,6 +171,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] return r, fmt.Errorf("error extracting token: %w", err) } + if m.patUsageTracker != nil { + m.patUsageTracker.IncrementUsage(token) + } + if m.rateLimiter != nil { if !m.rateLimiter.Allow(token) { return r, status.Errorf(status.TooManyRequests, "too many requests") diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 7badc03e4..ba4d16796 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -208,6 +208,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { return &types.User{}, nil }, nil, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -266,6 +267,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -317,6 +319,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -359,6 +362,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -402,6 +406,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -465,6 +470,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { return &types.User{}, nil }, rateLimitConfig, + nil, ) handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -581,6 +587,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { return &types.User{}, nil }, nil, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/pat_usage_tracker.go b/management/server/http/middleware/pat_usage_tracker.go new file mode 100644 index 000000000..331c288e7 --- /dev/null +++ b/management/server/http/middleware/pat_usage_tracker.go @@ -0,0 +1,87 @@ +package middleware + +import ( + "context" + "maps" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" +) + +// PATUsageTracker tracks PAT usage metrics +type PATUsageTracker struct { + usageCounters map[string]int64 + mu sync.Mutex + stopChan chan struct{} + ctx context.Context + histogram metric.Int64Histogram +} + +// NewPATUsageTracker creates a new PAT usage tracker with metrics +func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) { + histogram, err := meter.Int64Histogram( + "management.pat.usage_distribution", + metric.WithUnit("1"), + metric.WithDescription("Distribution of PAT token usage counts per minute"), + ) + if err != nil { + return nil, err + } + + tracker := &PATUsageTracker{ + usageCounters: make(map[string]int64), + stopChan: make(chan struct{}), + ctx: ctx, + histogram: histogram, + } + + go tracker.reportLoop() + + return tracker, nil +} + +// IncrementUsage increments the usage counter for a given token +func (t *PATUsageTracker) IncrementUsage(token string) { + t.mu.Lock() + defer t.mu.Unlock() + t.usageCounters[token]++ +} + +// reportLoop reports the usage buckets every minute +func (t *PATUsageTracker) reportLoop() { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + t.reportUsageBuckets() + case <-t.stopChan: + return + } + } +} + +// reportUsageBuckets reports all token usage counts and resets counters +func (t *PATUsageTracker) reportUsageBuckets() { + t.mu.Lock() + snapshot := maps.Clone(t.usageCounters) + + clear(t.usageCounters) + t.mu.Unlock() + + totalTokens := len(snapshot) + if totalTokens > 0 { + for _, count := range snapshot { + t.histogram.Record(t.ctx, count) + } + log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens) + } +} + +// Stop stops the reporting goroutine +func (t *PATUsageTracker) Stop() { + close(t.stopChan) +}