mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
[management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy --------- Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk> Co-authored-by: mlsmaycon <mlsmaycon@gmail.com> Co-authored-by: Eduard Gert <kontakt@eduardgert.de> Co-authored-by: Viktor Liu <viktor@netbird.io> Co-authored-by: Diego Noguês <diego.sure@gmail.com> Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
208
management/server/http/handlers/proxy/auth.go
Normal file
208
management/server/http/handlers/proxy/auth.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
)
|
||||
|
||||
// AuthCallbackHandler handles OAuth callbacks for proxy authentication.
|
||||
type AuthCallbackHandler struct {
|
||||
proxyService *nbgrpc.ProxyServiceServer
|
||||
rateLimiter *middleware.APIRateLimiter
|
||||
trustedProxies []netip.Prefix
|
||||
}
|
||||
|
||||
// NewAuthCallbackHandler creates a new OAuth callback handler.
|
||||
func NewAuthCallbackHandler(proxyService *nbgrpc.ProxyServiceServer, trustedProxies []netip.Prefix) *AuthCallbackHandler {
|
||||
rateLimiterConfig := &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: 10,
|
||||
Burst: 15,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
return &AuthCallbackHandler{
|
||||
proxyService: proxyService,
|
||||
rateLimiter: middleware.NewAPIRateLimiter(rateLimiterConfig),
|
||||
trustedProxies: trustedProxies,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers the OAuth callback endpoint.
|
||||
func (h *AuthCallbackHandler) RegisterEndpoints(router *mux.Router) {
|
||||
router.HandleFunc(types.ProxyCallbackEndpoint, h.handleCallback).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
func (h *AuthCallbackHandler) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
clientIP := h.resolveClientIP(r)
|
||||
if !h.rateLimiter.Allow(clientIP) {
|
||||
log.WithField("client_ip", clientIP).Warn("OAuth callback rate limit exceeded")
|
||||
http.Error(w, "Too many requests. Please try again later.", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
codeVerifier, originalURL, err := h.proxyService.ValidateState(state)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("OAuth callback state validation failed")
|
||||
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(originalURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse redirect URL")
|
||||
http.Error(w, "Invalid redirect URL", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
oidcConfig := h.proxyService.GetOIDCConfig()
|
||||
|
||||
provider, err := oidc.NewProvider(r.Context(), oidcConfig.Issuer)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create OIDC provider")
|
||||
http.Error(w, "Failed to create OIDC provider", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
token, err := (&oauth2.Config{
|
||||
ClientID: oidcConfig.ClientID,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: oidcConfig.CallbackURL,
|
||||
}).Exchange(r.Context(), r.URL.Query().Get("code"), oauth2.VerifierOption(codeVerifier))
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to exchange code for token")
|
||||
http.Error(w, "Failed to exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userID := extractUserIDFromToken(r.Context(), provider, oidcConfig, token)
|
||||
if userID == "" {
|
||||
log.Error("Failed to extract user ID from OIDC token")
|
||||
http.Error(w, "Failed to validate token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Group validation is performed by the proxy via ValidateSession gRPC call.
|
||||
// This allows the proxy to show 403 pages directly without redirect dance.
|
||||
|
||||
sessionToken, err := h.proxyService.GenerateSessionToken(r.Context(), redirectURL.Hostname(), userID, auth.MethodOIDC)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to create session token")
|
||||
redirectURL.Scheme = "https"
|
||||
query := redirectURL.Query()
|
||||
query.Set("error", "access_denied")
|
||||
query.Set("error_description", "Service configuration error")
|
||||
redirectURL.RawQuery = query.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL.Scheme = "https"
|
||||
|
||||
query := redirectURL.Query()
|
||||
query.Set("session_token", sessionToken)
|
||||
redirectURL.RawQuery = query.Encode()
|
||||
|
||||
log.WithField("redirect", redirectURL.Host).Debug("OAuth callback: redirecting user with session token")
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
func extractUserIDFromToken(ctx context.Context, provider *oidc.Provider, config nbgrpc.ProxyOIDCConfig, token *oauth2.Token) string {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
log.Warn("No id_token in OIDC response")
|
||||
return ""
|
||||
}
|
||||
|
||||
verifier := provider.Verifier(&oidc.Config{
|
||||
ClientID: config.ClientID,
|
||||
})
|
||||
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("Failed to verify ID token")
|
||||
return ""
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Subject string `json:"sub"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
log.WithError(err).Warn("Failed to extract claims from ID token")
|
||||
return ""
|
||||
}
|
||||
|
||||
return claims.Subject
|
||||
}
|
||||
|
||||
// resolveClientIP extracts the real client IP from the request.
|
||||
// When trustedProxies is non-empty and the direct peer is trusted,
|
||||
// it walks X-Forwarded-For right-to-left skipping trusted IPs.
|
||||
// Otherwise it returns RemoteAddr directly.
|
||||
func (h *AuthCallbackHandler) resolveClientIP(r *http.Request) string {
|
||||
remoteIP := extractHost(r.RemoteAddr)
|
||||
|
||||
if len(h.trustedProxies) == 0 || !isTrustedProxy(remoteIP, h.trustedProxies) {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
xff := r.Header.Get("X-Forwarded-For")
|
||||
if xff == "" {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
parts := strings.Split(xff, ",")
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
ip := strings.TrimSpace(parts[i])
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
if !isTrustedProxy(ip, h.trustedProxies) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||
return first
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
func extractHost(remoteAddr string) string {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func isTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
addr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user