63 lines
1.1 KiB
Go
63 lines
1.1 KiB
Go
package security
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// bucket is the token-bucket state kept for a single rate-limit key.
type bucket struct {
	tokens float64   // tokens available as of last
	last   time.Time // when tokens was last brought up to date
}

// level reports how many tokens b holds at now, given a refill rate of rps
// tokens per second and a capacity of burst. It does not modify b.
func (b *bucket) level(now time.Time, rps, burst float64) float64 {
	t := b.tokens + now.Sub(b.last).Seconds()*rps
	// Written as burst < t (not t > burst) so a NaN level is passed through
	// unchanged, exactly as the original min(burst, t) clamp behaved.
	if burst < t {
		t = burst
	}
	return t
}

// sweepInterval is the minimum time between passes that drop refilled
// buckets from a Limiter. Each pass is O(number of keys) and runs under the
// Limiter's lock, so it is kept infrequent.
const sweepInterval = time.Minute

// Limiter is a per-key token-bucket rate limiter. Each key starts with burst
// tokens and regains rps tokens per second, capped at burst; every allowed
// request spends one token. All bucket state is guarded by mu, so a Limiter
// may be shared by concurrent callers. Create one with NewLimiter.
type Limiter struct {
	mu        sync.Mutex
	m         map[string]*bucket
	rps       float64
	burst     float64
	lastSweep time.Time // when refilled buckets were last removed from m
}

// NewLimiter returns a Limiter that refills each key at rps tokens per second
// and lets it accumulate at most burst tokens. The arguments are not
// validated; in particular a burst below 1 denies every request.
func NewLimiter(rps, burst float64) *Limiter {
	return &Limiter{m: map[string]*bucket{}, rps: rps, burst: burst}
}

// allow reports whether a request for key may proceed, spending one token
// from key's bucket if it does.
func (l *Limiter) allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()
	if now.Sub(l.lastSweep) >= sweepInterval {
		l.sweep(now)
	}

	b := l.m[key]
	if b == nil {
		b = &bucket{tokens: l.burst, last: now}
		l.m[key] = b
	}
	b.tokens = b.level(now, l.rps, l.burst)
	b.last = now
	if b.tokens >= 1 {
		b.tokens -= 1
		return true
	}
	return false
}

// sweep deletes every bucket that has refilled to capacity. A full bucket
// behaves exactly like the fresh bucket allow would create for its key, so
// dropping it changes no decision. It does keep m from growing without bound
// as new keys arrive; Middleware derives keys from the request path, which
// clients control. The caller must hold l.mu.
func (l *Limiter) sweep(now time.Time) {
	for k, b := range l.m {
		if b.level(now, l.rps, l.burst) >= l.burst {
			delete(l.m, k) // deleting during range is permitted by the spec
		}
	}
	l.lastSweep = now
}
|
|
|
|
func (l *Limiter) Middleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
|
key := host + r.URL.Path
|
|
if !l.allow(key) {
|
|
http.Error(w, "rate limit", http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// min returns the smaller of a and b. A tie, or an unordered comparison
// (either operand NaN), yields b. On Go 1.21 and later this package-level
// declaration shadows the builtin min within the package.
func min(a, b float64) float64 {
	// Spelled !(a < b) rather than a >= b: the two differ when an operand is
	// NaN, and only this form keeps "otherwise return b".
	if !(a < b) {
		return b
	}
	return a
}
|