mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 23:56:38 +00:00
add rate limiting for callback endpoint
This commit is contained in:
@@ -2,8 +2,11 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
@@ -11,6 +14,7 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
"github.com/netbirdio/netbird/proxy/auth"
|
"github.com/netbirdio/netbird/proxy/auth"
|
||||||
)
|
)
|
||||||
@@ -18,12 +22,21 @@ import (
|
|||||||
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
||||||
type AuthCallbackHandler struct {
|
type AuthCallbackHandler struct {
|
||||||
proxyService *nbgrpc.ProxyServiceServer
|
proxyService *nbgrpc.ProxyServiceServer
|
||||||
|
rateLimiter *middleware.APIRateLimiter
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||||
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
|
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
|
||||||
|
rateLimiterConfig := &middleware.RateLimiterConfig{
|
||||||
|
RequestsPerMinute: 10,
|
||||||
|
Burst: 15,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
LimiterTTL: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
return &AuthCallbackHandler{
|
return &AuthCallbackHandler{
|
||||||
proxyService: proxyService,
|
proxyService: proxyService,
|
||||||
|
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -33,6 +46,13 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
clientIP := getClientIP(r)
|
||||||
|
if !h.rateLimiter.Allow(clientIP) {
|
||||||
|
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
|
||||||
|
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
|
|
||||||
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
|
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
|
||||||
@@ -128,3 +148,24 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
|
|||||||
|
|
||||||
return claims.Subject
|
return claims.Subject
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getClientIP extracts the client IP address from the request.
|
||||||
|
func getClientIP(r *http.Request) string {
|
||||||
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
if idx := strings.Index(xff, ","); idx != -1 {
|
||||||
|
return strings.TrimSpace(xff[:idx])
|
||||||
|
}
|
||||||
|
return xff
|
||||||
|
}
|
||||||
|
|
||||||
|
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||||
|
return xri
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to RemoteAddr
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return r.RemoteAddr
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
|||||||
156
management/server/http/handlers/proxy/auth_test.go
Normal file
156
management/server/http/handlers/proxy/auth_test.go
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthCallbackHandler_RateLimiting(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
|
||||||
|
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
|
||||||
|
req.RemoteAddr = "192.168.1.100:12345"
|
||||||
|
|
||||||
|
t.Run("allows requests under limit", func(t *testing.T) {
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
allowed := handler.rateLimiter.Allow("192.168.1.100")
|
||||||
|
assert.True(t, allowed, "Request %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocks requests over limit", func(t *testing.T) {
|
||||||
|
handler.rateLimiter.Reset("192.168.1.200")
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
handler.rateLimiter.Allow("192.168.1.200")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := handler.rateLimiter.Allow("192.168.1.200")
|
||||||
|
assert.False(t, allowed, "Request over limit should be blocked")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("different IPs have separate limits", func(t *testing.T) {
|
||||||
|
ip1 := "192.168.1.201"
|
||||||
|
ip2 := "192.168.1.202"
|
||||||
|
|
||||||
|
handler.rateLimiter.Reset(ip1)
|
||||||
|
handler.rateLimiter.Reset(ip2)
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
handler.rateLimiter.Allow(ip1)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.False(t, handler.rateLimiter.Allow(ip1), "IP1 should be blocked")
|
||||||
|
|
||||||
|
assert.True(t, handler.rateLimiter.Allow(ip2), "IP2 should be allowed")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallbackHandler_RateLimitInHandleCallback(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
|
||||||
|
testIP := "10.0.0.50"
|
||||||
|
|
||||||
|
handler.rateLimiter.Reset(testIP)
|
||||||
|
|
||||||
|
t.Run("returns 429 when rate limited", func(t *testing.T) {
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
handler.rateLimiter.Allow(testIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/callback?state=test&code=test", nil)
|
||||||
|
req.RemoteAddr = testIP + ":12345"
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.handleCallback(rr, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, rr.Code, "Should return 429 status code")
|
||||||
|
assert.Contains(t, rr.Body.String(), "Too many requests", "Should contain rate limit message")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
remoteAddr string
|
||||||
|
xForwardedFor string
|
||||||
|
xRealIP string
|
||||||
|
expectedIP string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from RemoteAddr",
|
||||||
|
remoteAddr: "192.168.1.100:12345",
|
||||||
|
expectedIP: "192.168.1.100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from X-Forwarded-For single IP",
|
||||||
|
remoteAddr: "10.0.0.1:54321",
|
||||||
|
xForwardedFor: "203.0.113.195",
|
||||||
|
expectedIP: "203.0.113.195",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from X-Forwarded-For multiple IPs",
|
||||||
|
remoteAddr: "10.0.0.1:54321",
|
||||||
|
xForwardedFor: "203.0.113.195, 70.41.3.18, 150.172.238.178",
|
||||||
|
expectedIP: "203.0.113.195",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from X-Real-IP",
|
||||||
|
remoteAddr: "10.0.0.1:54321",
|
||||||
|
xRealIP: "198.51.100.42",
|
||||||
|
expectedIP: "198.51.100.42",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||||
|
remoteAddr: "10.0.0.1:54321",
|
||||||
|
xForwardedFor: "203.0.113.195",
|
||||||
|
xRealIP: "198.51.100.42",
|
||||||
|
expectedIP: "203.0.113.195",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handle RemoteAddr without port",
|
||||||
|
remoteAddr: "192.168.1.100",
|
||||||
|
expectedIP: "192.168.1.100",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||||
|
req.RemoteAddr = tt.remoteAddr
|
||||||
|
|
||||||
|
if tt.xForwardedFor != "" {
|
||||||
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||||
|
}
|
||||||
|
if tt.xRealIP != "" {
|
||||||
|
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := getClientIP(req)
|
||||||
|
assert.Equal(t, tt.expectedIP, ip, "Extracted IP should match expected")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCallbackHandler_RateLimiterConfiguration(t *testing.T) {
|
||||||
|
handler := NewAuthCallbackHandler(&nbgrpc.ProxyServiceServer{})
|
||||||
|
|
||||||
|
require.NotNil(t, handler.rateLimiter, "Rate limiter should be initialized")
|
||||||
|
|
||||||
|
testIP := "192.168.1.250"
|
||||||
|
handler.rateLimiter.Reset(testIP)
|
||||||
|
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
allowed := handler.rateLimiter.Allow(testIP)
|
||||||
|
assert.True(t, allowed, "Should allow request %d within burst limit", i+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := handler.rateLimiter.Allow(testIP)
|
||||||
|
assert.False(t, allowed, "Should block request that exceeds burst limit")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user