mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
add rate limiting for callback endpoint
This commit is contained in:
@@ -2,8 +2,11 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
@@ -11,6 +14,7 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
@@ -18,12 +22,21 @@ import (
|
||||
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
||||
type AuthCallbackHandler struct {
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
rateLimiter *middleware.APIRateLimiter
|
||||
}
|
||||
|
||||
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
|
||||
rateLimiterConfig := &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: 10,
|
||||
Burst: 15,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
return &AuthCallbackHandler{
|
||||
proxyService: proxyService,
|
||||
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +46,13 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
|
||||
}
|
||||
|
||||
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
if !h.rateLimiter.Allow(clientIP) {
|
||||
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
|
||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
|
||||
@@ -128,3 +148,24 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
|
||||
|
||||
return claims.Subject
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return xff
|
||||
}
|
||||
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user