Validate trusted proxies in OAuth callback getClientIP

This commit is contained in:
Viktor Liu
2026-02-12 21:15:41 +08:00
parent 7fdb824a37
commit 9554934b92
6 changed files with 126 additions and 55 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"net/netip"
"net/url"
"strings"
"time"
@@ -21,12 +22,13 @@ import (
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
type AuthCallbackHandler struct {
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
proxyService *nbgrpc.ProxyServiceServer
rateLimiter *middleware.APIRateLimiter
trustedProxies []netip.Prefix
}
// NewAuthCallbackHandler creates a new OAuth callback handler.
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallbackHandler {
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
rateLimiterConfig := &middleware.RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 15,
@@ -35,8 +37,9 @@ func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer) *AuthCallba
}
return &AuthCallbackHandler{
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
proxyService: proxyService,
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
trustedProxies: trustedProxies,
}
}
@@ -46,7 +49,7 @@ func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
}
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
clientIP := getClientIP(r)
clientIP := h.resolveClientIP(r)
if !h.rateLimiter.Allow(clientIP) {
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
@@ -149,23 +152,57 @@ func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config
return claims.Subject
}
// getClientIP extracts the client IP address from the request.
func getClientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
// resolveClientIP extracts the real client IP from the request.
// When trustedProxies is non-empty and the direct peer is trusted,
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
// Otherwise it returns RemoteAddr directly.
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
remoteIP := extractHost(r.RemoteAddr)
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
return remoteIP
}
xff := r.Header.Get("X-Forwarded-For")
if xff == "" {
return remoteIP
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" {
continue
}
if !isTrustedProxy(ip, h.trustedProxies) {
return ip
}
return xff
}
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
}
return remoteIP
}
// Fall back to RemoteAddr
host, _, err := net.SplitHostPort(r.RemoteAddr)
func extractHost(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return r.RemoteAddr
return remoteAddr
}
return host
}
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
addr, err := netip.ParseAddr(ipStr)
if err != nil {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}