Files
ntfywui/internal/security/ratelimit.go
2026-01-12 13:51:52 +01:00

81 lines
1.6 KiB
Go

package security
import (
"net/http"
"sync"
"time"
)
// bucket is the per-key token-bucket state tracked by RateLimiter.
type bucket struct {
	tokens  float64   // tokens currently available, capped at RateLimiter.capacity
	last    time.Time // time of the most recent refill computation for this key
	blocked time.Time // NOTE(review): not read or written in this file; remove if unused elsewhere in the package
}

// RateLimiter is a keyed token-bucket rate limiter. Each key owns a bucket
// that holds at most capacity tokens and refills continuously at refillPer
// tokens per second; every allowed request consumes one token. Buckets not
// touched by Allow for longer than ttl become eligible for eviction.
//
// All mutable state is guarded by mu, so a *RateLimiter may be shared
// between goroutines.
type RateLimiter struct {
	mu        sync.Mutex
	capacity  float64
	refillPer float64 // tokens/sec
	ttl       time.Duration
	buckets   map[string]*bucket

	// lastSweep is when stale buckets were last evicted. It throttles the
	// O(len(buckets)) eviction sweep performed by Allow.
	lastSweep time.Time
}

// NewRateLimiter returns a RateLimiter whose buckets hold at most capacity
// tokens, refill at refillPerSec tokens per second, and may be evicted after
// ttl without activity. Keys seen for the first time start with a full bucket.
func NewRateLimiter(capacity int, refillPerSec float64, ttl time.Duration) *RateLimiter {
	return &RateLimiter{
		capacity:  float64(capacity),
		refillPer: refillPerSec,
		ttl:       ttl,
		buckets:   make(map[string]*bucket),
	}
}

// Allow reports whether a request for key may proceed. When it returns true,
// one token has been consumed from key's bucket. A key that has never been
// seen (or whose bucket was evicted) starts with a full bucket.
//
// Once more than sweepThreshold buckets are tracked, buckets idle for longer
// than ttl are evicted. The sweep walks the whole map while holding the lock,
// so it runs at most once per max(ttl, minSweepInterval); previously it ran on
// every call, making each request O(n) once the map was large.
func (rl *RateLimiter) Allow(key string) bool {
	const (
		sweepThreshold   = 10000
		minSweepInterval = time.Second
	)

	now := time.Now()
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Tolerate a zero-value RateLimiter instead of panicking on a nil map write.
	if rl.buckets == nil {
		rl.buckets = make(map[string]*bucket)
	}

	// Sweep before looking up key: if this key's own bucket is stale it is
	// evicted and recreated, rather than deleted after we already hold a
	// pointer to it (which would silently discard the update below).
	if len(rl.buckets) > sweepThreshold {
		interval := rl.ttl
		if interval < minSweepInterval {
			interval = minSweepInterval
		}
		if now.Sub(rl.lastSweep) >= interval {
			rl.evictStaleLocked(now)
		}
	}

	b, ok := rl.buckets[key]
	if !ok {
		b = &bucket{tokens: rl.capacity, last: now}
		rl.buckets[key] = b
	}

	// Refill in proportion to the time since the last refill, capped at capacity.
	if elapsed := now.Sub(b.last).Seconds(); elapsed > 0 {
		b.tokens += elapsed * rl.refillPer
		if b.tokens > rl.capacity {
			b.tokens = rl.capacity
		}
		b.last = now
	}

	if b.tokens < 1 {
		return false
	}
	b.tokens--
	return true
}

// evictStaleLocked deletes every bucket whose last refill is more than ttl
// before now, then records now as the latest sweep time. The caller must
// hold rl.mu.
func (rl *RateLimiter) evictStaleLocked(now time.Time) {
	for k, v := range rl.buckets {
		// Deleting the current key during range is permitted by the Go spec.
		if now.Sub(v.last) > rl.ttl {
			delete(rl.buckets, k)
		}
	}
	rl.lastSweep = now
}
// Middleware returns HTTP middleware that consults rl.Allow before passing a
// request on to the wrapped handler.
//
// keyFn (which must be non-nil) is called once per request to choose the
// bucket key; requests for which it returns "" all share the "anon" bucket.
// A request that is not allowed is answered via http.Error with status 429
// (Too Many Requests), body "rate limited", and a "Retry-After: 2" header,
// and the wrapped handler is not called.
func (rl *RateLimiter) Middleware(keyFn func(r *http.Request) string) func(http.Handler) http.Handler {
	wrap := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			bucketKey := keyFn(r)
			if bucketKey == "" {
				bucketKey = "anon"
			}
			if rl.Allow(bucketKey) {
				next.ServeHTTP(w, r)
				return
			}
			w.Header().Set("Retry-After", "2")
			http.Error(w, "rate limited", http.StatusTooManyRequests)
		}
		return http.HandlerFunc(serve)
	}
	return wrap
}