mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 18:26:41 +00:00
Validate trusted proxies in OAuth callback getClientIP
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -21,12 +22,13 @@ import (
|
||||
|
||||
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
||||
type AuthCallbackHandler struct {
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
rateLimiter *middleware.APIRateLimiter
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
rateLimiter *middleware.APIRateLimiter
|
||||
trustedProxies []netip.Prefix
|
||||
}
|
||||
|
||||
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
|
||||
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
|
||||
rateLimiterConfig := &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: 10,
|
||||
Burst: 15,
|
||||
@@ -35,8 +37,9 @@ func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallba
|
||||
}
|
||||
|
||||
return &AuthCallbackHandler{
|
||||
proxyService: proxyService,
|
||||
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||
proxyService: proxyService,
|
||||
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||
trustedProxies: trustedProxies,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,7 +49,7 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
|
||||
}
|
||||
|
||||
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := getClientIP(r)
|
||||
clientIP := h.resolveClientIP(r)
|
||||
if !h.rateLimiter.Allow(clientIP) {
|
||||
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
|
||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||
@@ -149,23 +152,57 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
|
||||
return claims.Subject
|
||||
}
|
||||
|
||||
// getClientIP extracts the client IP address from the request.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
// resolveClientIP extracts the real client IP from the request.
|
||||
// When trustedProxies is non-empty and the direct peer is trusted,
|
||||
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
|
||||
// Otherwise it returns RemoteAddr directly.
|
||||
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
|
||||
remoteIP := extractHost(r.RemoteAddr)
|
||||
|
||||
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
parts := strings.Split(xff, ",")
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
ip := strings.TrimSpace(parts[i])
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
if !isTrustedProxy(ip, h.trustedProxies) {
|
||||
return ip
|
||||
}
|
||||
return xff
|
||||
}
|
||||
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||
return first
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
func extractHost(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return r.RemoteAddr
|
||||
return remoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
addr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user