mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] record pat usage metrics (#4888)
This commit is contained in:
@@ -105,6 +105,7 @@ func NewAPIHandler(
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
appMetrics.GetMeter(),
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
@@ -31,6 +32,7 @@ type AuthMiddleware struct {
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
patUsageTracker *PATUsageTracker
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -40,18 +42,29 @@ func NewAuthMiddleware(
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
meter metric.Meter,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
var patUsageTracker *PATUsageTracker
|
||||
if meter != nil {
|
||||
var err error
|
||||
patUsageTracker, err = NewPATUsageTracker(context.Background(), meter)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to create PAT usage tracker: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
patUsageTracker: patUsageTracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,6 +171,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.patUsageTracker != nil {
|
||||
m.patUsageTracker.IncrementUsage(token)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
|
||||
@@ -208,6 +208,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -266,6 +267,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -317,6 +319,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -359,6 +362,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -402,6 +406,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -465,6 +470,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
nil,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -581,6 +587,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
87
management/server/http/middleware/pat_usage_tracker.go
Normal file
87
management/server/http/middleware/pat_usage_tracker.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// PATUsageTracker tracks PAT usage metrics
|
||||
type PATUsageTracker struct {
|
||||
usageCounters map[string]int64
|
||||
mu sync.Mutex
|
||||
stopChan chan struct{}
|
||||
ctx context.Context
|
||||
histogram metric.Int64Histogram
|
||||
}
|
||||
|
||||
// NewPATUsageTracker creates a new PAT usage tracker with metrics
|
||||
func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) {
|
||||
histogram, err := meter.Int64Histogram(
|
||||
"management.pat.usage_distribution",
|
||||
metric.WithUnit("1"),
|
||||
metric.WithDescription("Distribution of PAT token usage counts per minute"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tracker := &PATUsageTracker{
|
||||
usageCounters: make(map[string]int64),
|
||||
stopChan: make(chan struct{}),
|
||||
ctx: ctx,
|
||||
histogram: histogram,
|
||||
}
|
||||
|
||||
go tracker.reportLoop()
|
||||
|
||||
return tracker, nil
|
||||
}
|
||||
|
||||
// IncrementUsage increments the usage counter for a given token
|
||||
func (t *PATUsageTracker) IncrementUsage(token string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.usageCounters[token]++
|
||||
}
|
||||
|
||||
// reportLoop reports the usage buckets every minute
|
||||
func (t *PATUsageTracker) reportLoop() {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
t.reportUsageBuckets()
|
||||
case <-t.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// reportUsageBuckets reports all token usage counts and resets counters
|
||||
func (t *PATUsageTracker) reportUsageBuckets() {
|
||||
t.mu.Lock()
|
||||
snapshot := maps.Clone(t.usageCounters)
|
||||
|
||||
clear(t.usageCounters)
|
||||
t.mu.Unlock()
|
||||
|
||||
totalTokens := len(snapshot)
|
||||
if totalTokens > 0 {
|
||||
for _, count := range snapshot {
|
||||
t.histogram.Record(t.ctx, count)
|
||||
}
|
||||
log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the reporting goroutine
|
||||
func (t *PATUsageTracker) Stop() {
|
||||
close(t.stopChan)
|
||||
}
|
||||
Reference in New Issue
Block a user