This commit is contained in:
62
internal/security/ratelimit.go
Normal file
62
internal/security/ratelimit.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
last time.Time
|
||||
}
|
||||
|
||||
type Limiter struct {
|
||||
mu sync.Mutex
|
||||
m map[string]*bucket
|
||||
rps float64
|
||||
burst float64
|
||||
}
|
||||
|
||||
func NewLimiter(rps, burst float64) *Limiter {
|
||||
return &Limiter{m: map[string]*bucket{}, rps: rps, burst: burst}
|
||||
}
|
||||
|
||||
func (l *Limiter) allow(key string) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
b := l.m[key]
|
||||
now := time.Now()
|
||||
if b == nil {
|
||||
b = &bucket{tokens: l.burst, last: now}
|
||||
l.m[key] = b
|
||||
}
|
||||
elapsed := now.Sub(b.last).Seconds()
|
||||
b.tokens = min(l.burst, b.tokens+elapsed*l.rps)
|
||||
b.last = now
|
||||
if b.tokens >= 1 {
|
||||
b.tokens -= 1
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (l *Limiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
key := host + r.URL.Path
|
||||
if !l.allow(key) {
|
||||
http.Error(w, "rate limit", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func min(a, b float64) float64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user