Add user invite link feature for embedded IdP (#5157)

This commit is contained in:
Misha Bragin
2026-01-27 09:42:20 +01:00
committed by GitHub
parent 44ab454a13
commit 7d791620a6
21 changed files with 4832 additions and 2 deletions

View File

@@ -2,10 +2,14 @@ package middleware
import (
"context"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"github.com/netbirdio/netbird/shared/management/http/util"
)
// RateLimiterConfig holds configuration for the API rate limiter
@@ -144,3 +148,25 @@ func (rl *APIRateLimiter) Reset(key string) {
defer rl.mu.Unlock()
delete(rl.limiters, key)
}
// Middleware returns an HTTP middleware that rate limits requests by client IP.
// Returns 429 Too Many Requests if the rate limit is exceeded.
func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
if !rl.Allow(clientIP) {
util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w)
return
}
next.ServeHTTP(w, r)
})
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return ip
}

View File

@@ -0,0 +1,158 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestAPIRateLimiter_Allow(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// First two requests should be allowed (burst)
assert.True(t, rl.Allow("test-key"))
assert.True(t, rl.Allow("test-key"))
// Third request should be denied (exceeded burst)
assert.False(t, rl.Allow("test-key"))
// Different key should be allowed
assert.True(t, rl.Allow("different-key"))
}
func TestAPIRateLimiter_Middleware(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60, // 1 per second
Burst: 2,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Wrap with rate limiter middleware
handler := rl.Middleware(nextHandler)
// First two requests should pass (burst)
for i := 0; i < 2; i++ {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusOK, rr.Code, "request %d should be allowed", i+1)
}
// Third request should be rate limited
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, http.StatusTooManyRequests, rr.Code)
}
func TestAPIRateLimiter_Middleware_DifferentIPs(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler := rl.Middleware(nextHandler)
// Request from first IP
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
rr1 := httptest.NewRecorder()
handler.ServeHTTP(rr1, req1)
assert.Equal(t, http.StatusOK, rr1.Code)
// Second request from first IP should be rate limited
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
req2.RemoteAddr = "192.168.1.1:12345"
rr2 := httptest.NewRecorder()
handler.ServeHTTP(rr2, req2)
assert.Equal(t, http.StatusTooManyRequests, rr2.Code)
// Request from different IP should be allowed
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
req3.RemoteAddr = "192.168.1.2:12345"
rr3 := httptest.NewRecorder()
handler.ServeHTTP(rr3, req3)
assert.Equal(t, http.StatusOK, rr3.Code)
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{
name: "remote addr with port",
remoteAddr: "192.168.1.1:12345",
expected: "192.168.1.1",
},
{
name: "remote addr without port",
remoteAddr: "192.168.1.1",
expected: "192.168.1.1",
},
{
name: "IPv6 with port",
remoteAddr: "[::1]:12345",
expected: "::1",
},
{
name: "IPv6 without port",
remoteAddr: "::1",
expected: "::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req.RemoteAddr = tc.remoteAddr
assert.Equal(t, tc.expected, getClientIP(req))
})
}
}
func TestAPIRateLimiter_Reset(t *testing.T) {
rl := NewAPIRateLimiter(&RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: time.Minute,
LimiterTTL: time.Minute,
})
defer rl.Stop()
// Use up the burst
assert.True(t, rl.Allow("test-key"))
assert.False(t, rl.Allow("test-key"))
// Reset the limiter
rl.Reset("test-key")
// Should be allowed again
assert.True(t, rl.Allow("test-key"))
}