mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add user invite link feature for embedded IdP (#5157)
This commit is contained in:
@@ -2,10 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that rate limits requests by client IP.
|
||||
// Returns 429 Too Many Requests if the rate limit is exceeded.
|
||||
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
if !rl.Allow(clientIP) {
|
||||
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
158
management/server/http/middleware/rate_limiter_test.go
Normal file
158
management/server/http/middleware/rate_limiter_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAPIRateLimiter_Allow(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// First two requests should be allowed (burst)
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
|
||||
// Third request should be denied (exceeded burst)
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Different key should be allowed
|
||||
assert.True(t, rl.Allow("different-key"))
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60, // 1 per second
|
||||
Burst: 2,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Create a simple handler that returns 200 OK
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Wrap with rate limiter middleware
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// First two requests should pass (burst)
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
handler := rl.Middleware(nextHandler)
|
||||
|
||||
// Request from first IP
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
rr1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr1, req1)
|
||||
assert.Equal(t, http.StatusOK, rr1.Code)
|
||||
|
||||
// Second request from first IP should be rate limited
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.1:12345"
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
|
||||
|
||||
// Request from different IP should be allowed
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req3.RemoteAddr = "192.168.1.2:12345"
|
||||
rr3 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr3, req3)
|
||||
assert.Equal(t, http.StatusOK, rr3.Code)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "remote addr with port",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "remote addr without port",
|
||||
remoteAddr: "192.168.1.1",
|
||||
expected: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 with port",
|
||||
remoteAddr: "[::1]:12345",
|
||||
expected: "::1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 without port",
|
||||
remoteAddr: "::1",
|
||||
expected: "::1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
assert.Equal(t, tc.expected, getClientIP(req))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIRateLimiter_Reset(t *testing.T) {
|
||||
rl := NewAPIRateLimiter(&RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: time.Minute,
|
||||
LimiterTTL: time.Minute,
|
||||
})
|
||||
defer rl.Stop()
|
||||
|
||||
// Use up the burst
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
assert.False(t, rl.Allow("test-key"))
|
||||
|
||||
// Reset the limiter
|
||||
rl.Reset("test-key")
|
||||
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow("test-key"))
|
||||
}
|
||||
Reference in New Issue
Block a user