mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
818 lines
24 KiB
Go
818 lines
24 KiB
Go
package reverseproxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/mux"
|
|
"golang.org/x/crypto/acme/autocert"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Proxy wraps a reverse proxy with dynamic routing
|
|
type Proxy struct {
|
|
config Config
|
|
mu sync.RWMutex
|
|
routes map[string]*RouteConfig // key is host/domain (for fast O(1) lookup)
|
|
server *http.Server
|
|
httpServer *http.Server
|
|
autocertManager *autocert.Manager
|
|
isRunning bool
|
|
requestCallback RequestDataCallback
|
|
}
|
|
|
|
// Config holds the reverse proxy configuration
|
|
type Config struct {
|
|
// ListenAddress is the address to listen on for HTTPS (default ":443")
|
|
ListenAddress string
|
|
|
|
// HTTPListenAddress is the address for HTTP (default ":80")
|
|
// Used for ACME challenges when HTTPS is enabled, or as main listener when HTTPS is disabled
|
|
HTTPListenAddress string
|
|
|
|
// EnableHTTPS enables automatic HTTPS with Let's Encrypt
|
|
EnableHTTPS bool
|
|
|
|
// TLSEmail is the email for Let's Encrypt registration
|
|
TLSEmail string
|
|
|
|
// CertCacheDir is the directory to cache certificates (default "./certs")
|
|
CertCacheDir string
|
|
|
|
// RequestDataCallback is called for each proxied request with metrics
|
|
RequestDataCallback RequestDataCallback
|
|
|
|
// OIDCConfig is the global OIDC/OAuth configuration for authentication
|
|
// This is shared across all routes that use Bearer authentication
|
|
// If nil, routes with Bearer auth will fail to initialize
|
|
OIDCConfig *OIDCConfig
|
|
}
|
|
|
|
// OIDCConfig holds the proxy-wide OIDC/OAuth settings used for Bearer
// authentication. The json tags drive JSON (de)serialization; the env tags
// name NB_OIDC_* variables, presumably read by an env-parsing loader outside
// this file (verify).
type OIDCConfig struct {
	// ---- Identity provider ----

	// ProviderURL is the identity provider URL, e.g. "https://accounts.google.com".
	ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"`
	// ClientID is the OAuth client ID.
	ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"`
	// ClientSecret is the OAuth client secret; empty for public clients.
	ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"`
	// RedirectURL is where the provider sends the user after login,
	// e.g. "http://localhost:54321/auth/callback".
	RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"`
	// Scopes are the requested scopes; documented default is
	// ["openid", "profile", "email"].
	Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"`

	// ---- JWT validation ----

	// JWTKeysLocation is the JWKS URL that public signing keys are fetched from.
	JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"`
	// JWTIssuer is the expected "iss" claim.
	JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"`
	// JWTAudience lists the accepted "aud" claim values.
	JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"`
	// JWTIdpSignkeyRefreshEnabled enables automatic refresh of signing keys.
	JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"`

	// ---- Session ----

	// SessionCookieName is the cookie that stores the session. New sets it to
	// "auth_session" when empty.
	SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"`
}
|
|
|
|
// RouteConfig defines a routing configuration
|
|
type RouteConfig struct {
|
|
// ID is a unique identifier for this route
|
|
ID string
|
|
|
|
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
|
|
Domain string
|
|
|
|
// PathMappings defines paths that should be forwarded to specific ports
|
|
// Key is the path prefix (e.g., "/", "/api", "/admin")
|
|
// Value is the target IP:port (e.g., "192.168.1.100:3000")
|
|
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
|
|
PathMappings map[string]string
|
|
|
|
// Conn is the network connection to use for this route
|
|
// This allows routing through specific tunnels (e.g., WireGuard) per route
|
|
// This connection will be reused for all requests to this route
|
|
Conn net.Conn
|
|
|
|
// AuthConfig is optional authentication configuration for this route
|
|
// Configure ONE of: BasicAuth, PIN, or Bearer (JWT/OIDC)
|
|
// If nil, requests pass through without authentication
|
|
AuthConfig *AuthConfig
|
|
|
|
// AuthRejectResponse is an optional custom response for authentication failures
|
|
// If nil, returns 401 Unauthorized with WWW-Authenticate header
|
|
AuthRejectResponse func(w http.ResponseWriter, r *http.Request)
|
|
}
|
|
|
|
// routeEntry represents a compiled route with its proxy
|
|
type routeEntry struct {
|
|
routeConfig *RouteConfig
|
|
path string
|
|
target string
|
|
proxy *httputil.ReverseProxy
|
|
handler http.Handler // handler wraps proxy with middleware (auth, logging, etc.)
|
|
}
|
|
|
|
// New creates a reverse proxy from config. It does not start any server;
// call Start for that.
//
// Defaults are applied to empty fields: ListenAddress ":443",
// HTTPListenAddress ":80", CertCacheDir "./certs" and, when OIDC is
// configured, OIDCConfig.SessionCookieName "auth_session".
//
// It returns an error when EnableHTTPS is set without a TLSEmail.
//
// The caller's OIDCConfig is copied before defaults are applied, so New never
// mutates a struct the caller still owns (it may be shared with other
// components or reused for another proxy).
func New(config Config) (*Proxy, error) {
	if config.ListenAddress == "" {
		config.ListenAddress = ":443"
	}
	if config.HTTPListenAddress == "" {
		config.HTTPListenAddress = ":80"
	}
	if config.CertCacheDir == "" {
		config.CertCacheDir = "./certs"
	}

	if config.EnableHTTPS && config.TLSEmail == "" {
		return nil, fmt.Errorf("TLSEmail is required when EnableHTTPS is true")
	}

	if config.OIDCConfig != nil {
		// Shallow copy: the slices (Scopes, JWTAudience) stay shared, but
		// nothing here modifies them.
		oidc := *config.OIDCConfig
		if oidc.SessionCookieName == "" {
			oidc.SessionCookieName = "auth_session"
		}
		config.OIDCConfig = &oidc
	}

	return &Proxy{
		config:          config,
		routes:          make(map[string]*RouteConfig),
		requestCallback: config.RequestDataCallback,
	}, nil
}
|
|
|
|
// Start starts the reverse proxy server
|
|
func (p *Proxy) Start() error {
|
|
p.mu.Lock()
|
|
if p.isRunning {
|
|
p.mu.Unlock()
|
|
return fmt.Errorf("reverse proxy already running")
|
|
}
|
|
p.isRunning = true
|
|
p.mu.Unlock()
|
|
|
|
// Build the main HTTP handler
|
|
handler := p.buildHandler()
|
|
|
|
if p.config.EnableHTTPS {
|
|
// Setup autocert manager with dynamic host policy
|
|
p.autocertManager = &autocert.Manager{
|
|
Cache: autocert.DirCache(p.config.CertCacheDir),
|
|
Prompt: autocert.AcceptTOS,
|
|
Email: p.config.TLSEmail,
|
|
HostPolicy: p.dynamicHostPolicy, // Use dynamic policy based on routes
|
|
}
|
|
|
|
// Start HTTP server for ACME challenges
|
|
p.httpServer = &http.Server{
|
|
Addr: p.config.HTTPListenAddress,
|
|
Handler: p.autocertManager.HTTPHandler(nil),
|
|
}
|
|
|
|
go func() {
|
|
log.Infof("Starting HTTP server on %s for ACME challenges", p.config.HTTPListenAddress)
|
|
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Errorf("HTTP server error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Start HTTPS server
|
|
p.server = &http.Server{
|
|
Addr: p.config.ListenAddress,
|
|
Handler: handler,
|
|
TLSConfig: p.autocertManager.TLSConfig(),
|
|
}
|
|
|
|
go func() {
|
|
log.Infof("Starting HTTPS server on %s", p.config.ListenAddress)
|
|
if err := p.server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
|
log.Errorf("HTTPS server error: %v", err)
|
|
p.mu.Lock()
|
|
p.isRunning = false
|
|
p.mu.Unlock()
|
|
}
|
|
}()
|
|
} else {
|
|
// Start HTTP server only
|
|
p.server = &http.Server{
|
|
Addr: p.config.HTTPListenAddress,
|
|
Handler: handler,
|
|
}
|
|
|
|
go func() {
|
|
log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress)
|
|
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
|
log.Errorf("HTTP server error: %v", err)
|
|
p.mu.Lock()
|
|
p.isRunning = false
|
|
p.mu.Unlock()
|
|
}
|
|
}()
|
|
}
|
|
|
|
log.Infof("Reverse proxy started with %d route(s)", len(p.routes))
|
|
return nil
|
|
}
|
|
|
|
// dynamicHostPolicy is the autocert HostPolicy. It approves certificate
// issuance only for hosts that currently have a configured route, so domains
// added with AddRoute become eligible without a restart. ctx is not used.
func (p *Proxy) dynamicHostPolicy(ctx context.Context, host string) error {
	// Drop a trailing ":port", if present, before the route lookup.
	name := host
	if i := strings.LastIndex(name, ":"); i >= 0 {
		name = name[:i]
	}

	p.mu.RLock()
	_, known := p.routes[name]
	p.mu.RUnlock()

	if !known {
		log.Warnf("Rejecting certificate request for unknown domain: %s", name)
		return fmt.Errorf("domain %s not configured in routes", name)
	}

	log.Infof("Allowing certificate for domain: %s", name)
	return nil
}
|
|
|
|
// Stop gracefully stops the reverse proxy
|
|
func (p *Proxy) Stop(ctx context.Context) error {
|
|
p.mu.Lock()
|
|
if !p.isRunning {
|
|
p.mu.Unlock()
|
|
return fmt.Errorf("reverse proxy not running")
|
|
}
|
|
p.mu.Unlock()
|
|
|
|
log.Info("Stopping reverse proxy...")
|
|
|
|
// Stop HTTPS server
|
|
if p.server != nil {
|
|
if err := p.server.Shutdown(ctx); err != nil {
|
|
return fmt.Errorf("failed to shutdown HTTPS server: %w", err)
|
|
}
|
|
}
|
|
|
|
// Stop HTTP server (ACME challenge server)
|
|
if p.httpServer != nil {
|
|
if err := p.httpServer.Shutdown(ctx); err != nil {
|
|
return fmt.Errorf("failed to shutdown HTTP server: %w", err)
|
|
}
|
|
}
|
|
|
|
p.mu.Lock()
|
|
p.isRunning = false
|
|
p.mu.Unlock()
|
|
|
|
log.Info("Reverse proxy stopped")
|
|
return nil
|
|
}
|
|
|
|
// buildHandler returns the top-level router. GET /auth/callback is handled by
// the global OIDC callback; every other request goes to dynamic proxy routing.
// The callback is registered first, so it takes precedence over the catch-all
// and is answered by the proxy on every host, never forwarded upstream.
func (p *Proxy) buildHandler() http.Handler {
	r := mux.NewRouter()
	r.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
	r.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
	return r
}
|
|
|
|
// handleProxyRequest is the catch-all handler. It resolves the route for the
// request's host and path, serves it through the route's handler chain (auth
// middleware + reverse proxy) and, if a request callback is configured,
// reports metrics afterwards. Requests with no matching route get a 404.
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
	began := time.Now()

	entry := p.findRoute(r.Host, r.URL.Path)
	if entry == nil {
		log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
		http.NotFound(w, r)
		return
	}

	// Capture the status code written by the handler chain. 200 is assumed
	// when WriteHeader is never called explicitly.
	rec := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
	entry.handler.ServeHTTP(rec, r)

	if p.requestCallback == nil {
		return
	}

	elapsed := time.Since(began)

	hostOnly := r.Host
	if i := strings.LastIndex(hostOnly, ":"); i >= 0 {
		hostOnly = hostOnly[:i]
	}

	// NOTE(review): presumably set by the auth middleware. If it is not
	// overwritten there, a client can spoof this value — confirm in wrapWithAuth.
	mechanism := r.Header.Get("X-Auth-Method")
	if mechanism == "" {
		mechanism = "none"
	}

	// extractAuthInfo treats any status other than 401/403 as success. Only
	// its user ID and success flag are used here.
	_, userID, authOK := extractAuthInfo(r, rec.statusCode)

	p.requestCallback(RequestData{
		ServiceID:     entry.routeConfig.ID,
		Host:          hostOnly,
		Path:          r.URL.Path,
		DurationMs:    elapsed.Milliseconds(),
		Method:        r.Method,
		ResponseCode:  int32(rec.statusCode),
		SourceIP:      extractSourceIP(r),
		AuthMechanism: mechanism,
		UserID:        userID,
		AuthSuccess:   authOK,
	})
}
|
|
|
|
// findRoute resolves host and path to a route entry, or returns nil when no
// route or path mapping matches.
//
// Any ":port" suffix is removed from host before the exact-match domain
// lookup. Within the route, the longest PathMappings key that is a string
// prefix of path wins. The catch-all keys "/" and "" are used only when no
// other key matches; "/" is preferred if both exist.
//
// Only the selected mapping is compiled into a reverse proxy and wrapped with
// the route's auth middleware. The routes lock is released before that work,
// so a slow middleware setup cannot block route updates or other requests.
//
// NOTE(review): the proxy, its http.Transport and the auth middleware are
// still rebuilt on every request. Each Transport keeps its own idle pool, so
// upstream connections are not reused. For a custom route Conn, several
// Transports end up sharing the same net.Conn. Caching compiled entries per
// RouteConfig (invalidated by AddRoute/UpdateRoute/RemoveRoute) would fix
// this, but needs a field on Proxy.
func (p *Proxy) findRoute(host, path string) *routeEntry {
	if idx := strings.LastIndex(host, ":"); idx != -1 {
		host = host[:idx]
	}

	p.mu.RLock()
	routeConfig, exists := p.routes[host]
	p.mu.RUnlock()
	if !exists {
		return nil
	}

	routePath, target, ok := matchPathMapping(routeConfig.PathMappings, path)
	if !ok {
		return nil
	}

	proxy := p.createProxy(routeConfig, target)

	// Always wrap with the auth middleware, even without auth configured, for
	// consistent auth handling and logging.
	handler := wrapWithAuth(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.config.OIDCConfig)
	logRouteAuth(routeConfig)

	return &routeEntry{
		routeConfig: routeConfig,
		path:        routePath,
		target:      target,
		proxy:       proxy,
		handler:     handler,
	}
}

// matchPathMapping returns the mapping key and target that serve path: the
// longest non-catch-all key that prefixes path, otherwise the "/" or ""
// catch-all. ok is false when nothing matches.
func matchPathMapping(mappings map[string]string, path string) (key, target string, ok bool) {
	for prefix, dest := range mappings {
		if prefix == "" || prefix == "/" {
			continue
		}
		if strings.HasPrefix(path, prefix) && (!ok || len(prefix) > len(key)) {
			key, target, ok = prefix, dest, true
		}
	}
	if ok {
		return key, target, true
	}

	for _, catchAll := range []string{"/", ""} {
		if dest, found := mappings[catchAll]; found {
			return catchAll, dest, true
		}
	}
	return "", "", false
}

// logRouteAuth emits a debug line describing the auth configured for route.
func logRouteAuth(route *RouteConfig) {
	if route.AuthConfig == nil || route.AuthConfig.IsEmpty() {
		log.WithFields(log.Fields{
			"route_id": route.ID,
		}).Debug("No authentication configured for route")
		return
	}

	var authType string
	switch {
	case route.AuthConfig.BasicAuth != nil:
		authType = "basic_auth"
	case route.AuthConfig.PIN != nil:
		authType = "pin"
	case route.AuthConfig.Bearer != nil:
		authType = "bearer_jwt"
	}

	log.WithFields(log.Fields{
		"route_id":  route.ID,
		"auth_type": authType,
	}).Debug("Auth middleware enabled for route")
}
|
|
|
|
// createProxy builds a single-host reverse proxy to "http://"+target whose
// transport dials through routeConfig.Conn.
//
// There are two transport flavours:
//   - *defaultConn (testing/development): a pooled transport that dials with
//     the wrapper's net.Dialer.
//   - any other net.Conn (e.g. a WireGuard tunnel connection): every dial
//     returns that same connection, with the idle pool limited to one
//     connection and no idle timeout.
//
// Upstream failures are answered with 502 Bad Gateway. If target cannot be
// parsed, the returned proxy answers every request with 502 and never
// contacts the network.
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
	targetURL, err := url.Parse("http://" + target)
	if err != nil {
		log.Errorf("Failed to parse target URL %s: %v", target, err)
		parseErr := fmt.Errorf("invalid proxy target %q: %w", target, err)
		return &httputil.ReverseProxy{
			Director: func(req *http.Request) {},
			// With a no-op Director, an absolute-form request
			// ("GET http://host:port/ HTTP/1.1") keeps its client-supplied URL.
			// With the default transport it would be forwarded there.
			// This transport fails every dial instead.
			Transport: &http.Transport{
				DialContext: func(context.Context, string, string) (net.Conn, error) {
					return nil, parseErr
				},
			},
			ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
				http.Error(w, "Bad Gateway", http.StatusBadGateway)
			},
		}
	}

	proxy := httputil.NewSingleHostReverseProxy(targetURL)

	if dc, ok := routeConfig.Conn.(*defaultConn); ok {
		// Testing/development: dial normally with the wrapper's dialer.
		proxy.Transport = &http.Transport{
			DialContext:           dc.dialer.DialContext,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
		log.Infof("Using default network dialer for route %s (testing mode)", routeConfig.ID)
	} else {
		// Route through the provided connection (WireGuard, etc.). Every dial
		// returns the same net.Conn, so the pool is capped at one connection.
		proxy.Transport = &http.Transport{
			DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
				log.Debugf("Using custom connection for route %s to %s", routeConfig.ID, address)
				return routeConfig.Conn, nil
			},
			MaxIdleConns:          1,
			MaxIdleConnsPerHost:   1,
			IdleConnTimeout:       0, // keep alive indefinitely
			DisableKeepAlives:     false,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
		log.Infof("Using custom connection for route %s", routeConfig.ID)
	}

	proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
		log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
	}

	return proxy
}
|
|
|
|
// AddRoute registers route under route.Domain.
//
// route must be non-nil and have an ID, a Domain, at least one path mapping
// and a non-nil Conn. Adding a second route for an existing domain is an
// error (use UpdateRoute instead).
//
// No server reload is needed: requests look routes up dynamically. With HTTPS
// enabled, the domain becomes eligible for a certificate, which is obtained
// the first time the domain is accessed.
func (p *Proxy) AddRoute(route *RouteConfig) error {
	switch {
	case route == nil:
		return fmt.Errorf("route cannot be nil")
	case route.ID == "":
		return fmt.Errorf("route ID is required")
	case route.Domain == "":
		return fmt.Errorf("route Domain is required")
	case len(route.PathMappings) == 0:
		return fmt.Errorf("route must have at least one path mapping")
	case route.Conn == nil:
		return fmt.Errorf("route connection (Conn) is required")
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	if _, dup := p.routes[route.Domain]; dup {
		return fmt.Errorf("route for domain %s already exists", route.Domain)
	}
	p.routes[route.Domain] = route

	log.WithFields(log.Fields{
		"route_id": route.ID,
		"domain":   route.Domain,
		"paths":    len(route.PathMappings),
	}).Info("Added route")

	return nil
}
|
|
|
|
// RemoveRoute unregisters the route for domain, or returns an error if none
// exists. Certificates already issued for the domain are not deleted here;
// dynamicHostPolicy simply stops approving new issuance for it.
func (p *Proxy) RemoveRoute(domain string) error {
	p.mu.Lock()
	_, found := p.routes[domain]
	if found {
		delete(p.routes, domain)
	}
	p.mu.Unlock()

	if !found {
		return fmt.Errorf("route for domain %s not found", domain)
	}

	log.Infof("Removed route for domain: %s", domain)
	return nil
}
|
|
|
|
// UpdateRoute replaces the route registered for route.Domain. The domain must
// already have a route.
//
// route is validated with the same rules as AddRoute: non-nil, with an ID, a
// Domain, at least one path mapping and a non-nil Conn. Without this, an
// update could install a route that AddRoute rejects: one with no path
// mappings 404s every request, and one with a nil Conn fails every upstream
// dial.
func (p *Proxy) UpdateRoute(route *RouteConfig) error {
	if route == nil {
		return fmt.Errorf("route cannot be nil")
	}
	if route.ID == "" {
		return fmt.Errorf("route ID is required")
	}
	if route.Domain == "" {
		return fmt.Errorf("route Domain is required")
	}
	if len(route.PathMappings) == 0 {
		return fmt.Errorf("route must have at least one path mapping")
	}
	if route.Conn == nil {
		return fmt.Errorf("route connection (Conn) is required")
	}

	p.mu.Lock()
	defer p.mu.Unlock()

	if _, exists := p.routes[route.Domain]; !exists {
		return fmt.Errorf("route for domain %s not found", route.Domain)
	}

	p.routes[route.Domain] = route

	log.WithFields(log.Fields{
		"route_id": route.ID,
		"domain":   route.Domain,
		"paths":    len(route.PathMappings),
	}).Info("Updated route")

	return nil
}
|
|
|
|
// ListRoutes returns the domains of all configured routes, sorted
// lexically. The routes live in a map whose iteration order is random;
// sorting gives callers a deterministic result.
func (p *Proxy) ListRoutes() []string {
	p.mu.RLock()
	domains := make([]string, 0, len(p.routes))
	for domain := range p.routes {
		domains = append(domains, domain)
	}
	p.mu.RUnlock()

	sort.Strings(domains)
	return domains
}
|
|
|
|
// GetRoute returns the route registered for domain, or an error if none
// exists. The returned pointer is the stored route itself, not a copy.
func (p *Proxy) GetRoute(domain string) (*RouteConfig, error) {
	p.mu.RLock()
	route, found := p.routes[domain]
	p.mu.RUnlock()

	if found {
		return route, nil
	}
	return nil, fmt.Errorf("route for domain %s not found", domain)
}
|
|
|
|
// IsRunning returns whether the proxy is running
|
|
func (p *Proxy) IsRunning() bool {
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
return p.isRunning
|
|
}
|
|
|
|
// GetConfig returns the proxy configuration
|
|
func (p *Proxy) GetConfig() Config {
|
|
return p.config
|
|
}
|
|
|
|
// responseWriter wraps http.ResponseWriter to capture status code
|
|
type responseWriter struct {
|
|
http.ResponseWriter
|
|
statusCode int
|
|
}
|
|
|
|
func (rw *responseWriter) WriteHeader(code int) {
|
|
rw.statusCode = code
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
// extractSourceIP returns the client IP for request metrics. It uses, in
// order: the first X-Forwarded-For entry, X-Real-IP, or the host part of
// RemoteAddr.
//
// RemoteAddr is split with net.SplitHostPort, so IPv6 peers yield "::1"
// rather than the bracketed "[::1]". If RemoteAddr has no port, it is
// returned unchanged.
//
// NOTE(review): the forwarding headers are client-controlled unless a trusted
// upstream overwrites them, so the reported IP can be spoofed. Confirm the
// deployment before relying on it for anything beyond metrics.
func extractSourceIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// Take the first (client-most) IP in the list.
		if idx := strings.Index(xff, ","); idx != -1 {
			return strings.TrimSpace(xff[:idx])
		}
		return strings.TrimSpace(xff)
	}

	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}

	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|
|
|
|
// extractAuthInfo classifies how a request tried to authenticate, for request
// metrics. It returns (authMechanism, userID, authSuccess).
//
// authSuccess depends only on statusCode: anything other than 401/403 counts
// as success. userID is best-effort and unverified (see
// extractUserIDFromBearer).
//
// Detection order:
//  1. Authorization header: "bearer", "basic" or "other". The scheme name is
//     matched case-insensitively, as RFC 7235 requires.
//  2. X-API-Key header.
//  3. TLS client certificate: "mtls"; the user ID is the leaf certificate's CN.
//  4. Non-empty "session" cookie.
//
// Otherwise the mechanism is "none".
//
// NOTE(review): the OIDC session cookie is configurable (default
// "auth_session") and is not detected by the fixed "session" check.
func extractAuthInfo(r *http.Request, statusCode int) (string, string, bool) {
	authSuccess := statusCode != http.StatusUnauthorized && statusCode != http.StatusForbidden

	if auth := r.Header.Get("Authorization"); auth != "" {
		switch {
		case hasAuthScheme(auth, "Bearer "):
			// Re-prefix with the canonical spelling, because
			// extractUserIDFromBearer strips exactly "Bearer ".
			return "bearer", extractUserIDFromBearer("Bearer " + auth[len("Bearer "):]), authSuccess
		case hasAuthScheme(auth, "Basic "):
			return "basic", extractUserIDFromBasic(auth), authSuccess
		default:
			return "other", "", authSuccess
		}
	}

	// Header.Get canonicalizes the key, so this single lookup also covers
	// "X-Api-Key" and every other capitalization.
	if r.Header.Get("X-API-Key") != "" {
		return "api_key", "", authSuccess
	}

	if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
		return "mtls", r.TLS.PeerCertificates[0].Subject.CommonName, authSuccess
	}

	if cookie, err := r.Cookie("session"); err == nil && cookie.Value != "" {
		return "session", "", authSuccess
	}

	return "none", "", authSuccess
}

// hasAuthScheme reports whether an Authorization value starts with scheme
// (including its trailing space), compared case-insensitively.
func hasAuthScheme(value, scheme string) bool {
	return len(value) >= len(scheme) && strings.EqualFold(value[:len(scheme)], scheme)
}
|
|
|
|
// extractUserIDFromBearer returns a user identifier taken from the JWT in a
// "Bearer <token>" Authorization value, or "" if none can be found.
//
// SECURITY: the token is decoded WITHOUT signature verification. The result
// is only fit for metrics and logging, never for authorization decisions.
//
// Claims are checked in order "sub", "user_id", "email",
// "preferred_username"; the first non-empty string value wins.
func extractUserIDFromBearer(auth string) string {
	token := strings.TrimPrefix(auth, "Bearer ")
	if token == "" {
		return ""
	}

	// A JWT is header.payload.signature; only the payload is needed.
	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		log.Debug("Invalid JWT format: expected 3 parts")
		return ""
	}

	raw, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		log.WithError(err).Debug("Failed to decode JWT payload")
		return ""
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(raw, &claims); err != nil {
		log.WithError(err).Debug("Failed to parse JWT claims")
		return ""
	}

	for _, key := range []string{"sub", "user_id", "email", "preferred_username"} {
		if v, ok := claims[key].(string); ok && v != "" {
			return v
		}
	}
	return ""
}
|
|
|
|
// extractUserIDFromBasic is the user-ID extractor for "Basic" Authorization
// values (format "Basic base64(username:password)"). It deliberately returns
// "" without decoding the credentials, so the proxy never handles the
// plaintext password. The upstream service performs the actual
// authentication; the proxy only records that Basic auth was used.
func extractUserIDFromBasic(auth string) string {
	return ""
}
|
|
|
|
// defaultConn is a lazy connection wrapper that uses the standard network dialer
|
|
// This is useful for testing or development when not using WireGuard tunnels
|
|
type defaultConn struct {
|
|
dialer *net.Dialer
|
|
mu sync.Mutex
|
|
conns map[string]net.Conn // cache connections by "network:address"
|
|
}
|
|
|
|
// Read always fails: defaultConn is a marker and is never read from directly.
func (dc *defaultConn) Read(b []byte) (n int, err error) {
	err = fmt.Errorf("Read not supported on defaultConn - use dial via Transport")
	return 0, err
}
|
|
|
|
// Write always fails: defaultConn is a marker and is never written to directly.
func (dc *defaultConn) Write(b []byte) (n int, err error) {
	err = fmt.Errorf("Write not supported on defaultConn - use dial via Transport")
	return 0, err
}
|
|
|
|
// Close closes every cached connection on a best-effort basis and resets the
// cache. It always returns nil.
func (dc *defaultConn) Close() error {
	dc.mu.Lock()
	defer dc.mu.Unlock()

	for key := range dc.conns {
		_ = dc.conns[key].Close() // best effort: individual close errors are ignored
	}
	dc.conns = make(map[string]net.Conn)
	return nil
}
|
|
|
|
// LocalAddr returns nil: defaultConn has no socket of its own.
func (dc *defaultConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil: defaultConn has no socket of its own.
func (dc *defaultConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline accepts and ignores the deadline.
func (dc *defaultConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline accepts and ignores the deadline.
func (dc *defaultConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline accepts and ignores the deadline.
func (dc *defaultConn) SetWriteDeadline(t time.Time) error {
	return nil
}
|
|
|
|
// NewDefaultConn returns a placeholder connection for testing or development
// without WireGuard tunnels. Assigned to RouteConfig.Conn, it makes the proxy
// dial targets with a standard net.Dialer (30s connect timeout, 30s TCP
// keep-alive) instead of a dedicated tunnel. The actual dialing happens when
// the HTTP Transport calls DialContext.
func NewDefaultConn() net.Conn {
	d := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	return &defaultConn{
		dialer: d,
		conns:  map[string]net.Conn{},
	}
}
|
|
|
|
// handleOIDCCallback serves the global /auth/callback endpoint shared by all
// routes. It delegates to HandleOIDCCallback with the proxy-wide OIDC config,
// and answers 500 when no OIDC configuration is present.
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
	cfg := p.config.OIDCConfig
	if cfg == nil {
		log.Error("OIDC callback received but no OIDC config found")
		http.Error(w, "Authentication not configured", http.StatusInternalServerError)
		return
	}

	// NOTE(review): the callback handler is rebuilt per request; if
	// HandleOIDCCallback does expensive setup, consider building it once.
	HandleOIDCCallback(cfg)(w, r)
}
|