mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management] record pat usage metrics (#4888)
This commit is contained in:
@@ -105,6 +105,7 @@ func NewAPIHandler(
|
|||||||
accountManager.SyncUserJWTGroups,
|
accountManager.SyncUserJWTGroups,
|
||||||
accountManager.GetUserFromUserAuth,
|
accountManager.GetUserFromUserAuth,
|
||||||
rateLimitingConfig,
|
rateLimitingConfig,
|
||||||
|
appMetrics.GetMeter(),
|
||||||
)
|
)
|
||||||
|
|
||||||
corsMiddleware := cors.AllowAll()
|
corsMiddleware := cors.AllowAll()
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
serverauth "github.com/netbirdio/netbird/management/server/auth"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
@@ -31,6 +32,7 @@ type AuthMiddleware struct {
|
|||||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||||
rateLimiter *APIRateLimiter
|
rateLimiter *APIRateLimiter
|
||||||
|
patUsageTracker *PATUsageTracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
@@ -40,18 +42,29 @@ func NewAuthMiddleware(
|
|||||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
rateLimiterConfig *RateLimiterConfig,
|
rateLimiterConfig *RateLimiterConfig,
|
||||||
|
meter metric.Meter,
|
||||||
) *AuthMiddleware {
|
) *AuthMiddleware {
|
||||||
var rateLimiter *APIRateLimiter
|
var rateLimiter *APIRateLimiter
|
||||||
if rateLimiterConfig != nil {
|
if rateLimiterConfig != nil {
|
||||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var patUsageTracker *PATUsageTracker
|
||||||
|
if meter != nil {
|
||||||
|
var err error
|
||||||
|
patUsageTracker, err = NewPATUsageTracker(context.Background(), meter)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to create PAT usage tracker: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
authManager: authManager,
|
authManager: authManager,
|
||||||
ensureAccount: ensureAccount,
|
ensureAccount: ensureAccount,
|
||||||
syncUserJWTGroups: syncUserJWTGroups,
|
syncUserJWTGroups: syncUserJWTGroups,
|
||||||
getUserFromUserAuth: getUserFromUserAuth,
|
getUserFromUserAuth: getUserFromUserAuth,
|
||||||
rateLimiter: rateLimiter,
|
rateLimiter: rateLimiter,
|
||||||
|
patUsageTracker: patUsageTracker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,6 +171,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
return r, fmt.Errorf("error extracting token: %w", err)
|
return r, fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.patUsageTracker != nil {
|
||||||
|
m.patUsageTracker.IncrementUsage(token)
|
||||||
|
}
|
||||||
|
|
||||||
if m.rateLimiter != nil {
|
if m.rateLimiter != nil {
|
||||||
if !m.rateLimiter.Allow(token) {
|
if !m.rateLimiter.Allow(token) {
|
||||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||||
|
|||||||
@@ -208,6 +208,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||||
@@ -266,6 +267,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -317,6 +319,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -359,6 +362,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -402,6 +406,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -465,6 +470,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
rateLimitConfig,
|
rateLimitConfig,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -581,6 +587,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
return &types.User{}, nil
|
return &types.User{}, nil
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
|||||||
87
management/server/http/middleware/pat_usage_tracker.go
Normal file
87
management/server/http/middleware/pat_usage_tracker.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"maps"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PATUsageTracker tracks PAT usage metrics
|
||||||
|
type PATUsageTracker struct {
|
||||||
|
usageCounters map[string]int64
|
||||||
|
mu sync.Mutex
|
||||||
|
stopChan chan struct{}
|
||||||
|
ctx context.Context
|
||||||
|
histogram metric.Int64Histogram
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPATUsageTracker creates a new PAT usage tracker with metrics
|
||||||
|
func NewPATUsageTracker(ctx context.Context, meter metric.Meter) (*PATUsageTracker, error) {
|
||||||
|
histogram, err := meter.Int64Histogram(
|
||||||
|
"management.pat.usage_distribution",
|
||||||
|
metric.WithUnit("1"),
|
||||||
|
metric.WithDescription("Distribution of PAT token usage counts per minute"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker := &PATUsageTracker{
|
||||||
|
usageCounters: make(map[string]int64),
|
||||||
|
stopChan: make(chan struct{}),
|
||||||
|
ctx: ctx,
|
||||||
|
histogram: histogram,
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.reportLoop()
|
||||||
|
|
||||||
|
return tracker, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementUsage increments the usage counter for a given token
|
||||||
|
func (t *PATUsageTracker) IncrementUsage(token string) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
t.usageCounters[token]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// reportLoop reports the usage buckets every minute
|
||||||
|
func (t *PATUsageTracker) reportLoop() {
|
||||||
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
t.reportUsageBuckets()
|
||||||
|
case <-t.stopChan:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reportUsageBuckets reports all token usage counts and resets counters
|
||||||
|
func (t *PATUsageTracker) reportUsageBuckets() {
|
||||||
|
t.mu.Lock()
|
||||||
|
snapshot := maps.Clone(t.usageCounters)
|
||||||
|
|
||||||
|
clear(t.usageCounters)
|
||||||
|
t.mu.Unlock()
|
||||||
|
|
||||||
|
totalTokens := len(snapshot)
|
||||||
|
if totalTokens > 0 {
|
||||||
|
for _, count := range snapshot {
|
||||||
|
t.histogram.Record(t.ctx, count)
|
||||||
|
}
|
||||||
|
log.Debugf("PAT usage in last minute: %d unique tokens used", totalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the reporting goroutine
|
||||||
|
func (t *PATUsageTracker) Stop() {
|
||||||
|
close(t.stopChan)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user