81 lines
1.6 KiB
Go
81 lines
1.6 KiB
Go
package security
|
|
|
|
import (
	"math"
	"net/http"
	"strconv"
	"sync"
	"time"
)
|
|
|
|
type bucket struct {
|
|
tokens float64
|
|
last time.Time
|
|
blocked time.Time
|
|
}
|
|
|
|
type RateLimiter struct {
|
|
mu sync.Mutex
|
|
capacity float64
|
|
refillPer float64 // tokens/sec
|
|
ttl time.Duration
|
|
buckets map[string]*bucket
|
|
}
|
|
|
|
// NewRateLimiter returns a RateLimiter whose per-key buckets hold up to
// capacity tokens and refill at refillPerSec tokens per second. Buckets that
// have gone untouched for longer than ttl may be discarded by Allow's cleanup.
// The arguments are stored as given; no validation is performed.
func NewRateLimiter(capacity int, refillPerSec float64, ttl time.Duration) *RateLimiter {
	rl := new(RateLimiter)
	rl.capacity = float64(capacity)
	rl.refillPer = refillPerSec
	rl.ttl = ttl
	rl.buckets = make(map[string]*bucket)
	return rl
}
|
|
|
|
func (rl *RateLimiter) Allow(key string) bool {
|
|
now := time.Now()
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
b := rl.buckets[key]
|
|
if b == nil {
|
|
b = &bucket{tokens: rl.capacity, last: now}
|
|
rl.buckets[key] = b
|
|
}
|
|
// cleanup occasionally
|
|
if len(rl.buckets) > 10000 {
|
|
for k, v := range rl.buckets {
|
|
if now.Sub(v.last) > rl.ttl {
|
|
delete(rl.buckets, k)
|
|
}
|
|
}
|
|
}
|
|
|
|
elapsed := now.Sub(b.last).Seconds()
|
|
if elapsed > 0 {
|
|
b.tokens += elapsed * rl.refillPer
|
|
if b.tokens > rl.capacity {
|
|
b.tokens = rl.capacity
|
|
}
|
|
b.last = now
|
|
}
|
|
if b.tokens >= 1 {
|
|
b.tokens -= 1
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Middleware returns HTTP middleware that rate-limits requests per key.
//
// keyFn derives the rate-limit key from each request. Requests for which it
// returns "" all share the single key "anon". A request rejected by Allow is
// answered with 429 Too Many Requests and a Retry-After header, and next is
// not called; otherwise the request is passed to next.
//
// Retry-After is ceil(1/refillPer) seconds, roughly the time a drained bucket
// needs to accrue one token, so clients are not told to retry before a token
// can exist. If that value cannot be derived (refillPer is not positive,
// NaN, or so small the result is infinite), the header falls back to "2".
func (rl *RateLimiter) Middleware(keyFn func(r *http.Request) string) func(http.Handler) http.Handler {
	// refillPer is only assigned by NewRateLimiter in this file. Read it
	// once, under the lock to be safe, and build the header value up front
	// rather than on every rejected request.
	rl.mu.Lock()
	rate := rl.refillPer
	rl.mu.Unlock()

	retryAfter := "2" // fallback when no meaningful wait can be derived
	if rate > 0 {
		if secs := math.Ceil(1 / rate); !math.IsInf(secs, 0) {
			retryAfter = strconv.FormatFloat(secs, 'f', 0, 64)
		}
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			key := keyFn(r)
			if key == "" {
				// Requests without a key all share one bucket.
				key = "anon"
			}
			if !rl.Allow(key) {
				w.Header().Set("Retry-After", retryAfter)
				http.Error(w, "rate limited", http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
|