mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
[management] add pat rate limiting (#4741)
This commit is contained in:
@@ -29,6 +29,7 @@ type AuthMiddleware struct {
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
err = status.Errorf(status.Unauthorized, "token invalid")
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user