This commit is contained in:
pascal
2026-01-15 14:54:33 +01:00
parent 12b38e25da
commit ed5f98da5b
22 changed files with 1511 additions and 1392 deletions

View File

@@ -0,0 +1,262 @@
package reverseproxy
import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sort"
"strings"
"time"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth"
)
// buildHandler creates the main HTTP handler with router for static endpoints
func (p *Proxy) buildHandler() http.Handler {
router := mux.NewRouter()
// Register static endpoints
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
// Catch-all handler for dynamic proxy routing
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
return router
}
// handleProxyRequest handles all dynamic proxy requests
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
routeEntry := p.findRoute(r.Host, r.URL.Path)
if routeEntry == nil {
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
http.NotFound(w, r)
return
}
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
routeEntry.handler.ServeHTTP(rw, r)
if p.requestCallback != nil {
duration := time.Since(startTime)
host := r.Host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// Get auth info from headers set by auth middleware
authMechanism := r.Header.Get("X-Auth-Method")
if authMechanism == "" {
authMechanism = "none"
}
userID := r.Header.Get("X-Auth-User-ID")
// Determine auth success based on status code
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
// Extract source IP directly
sourceIP := extractSourceIP(r)
data := RequestData{
ServiceID: routeEntry.routeConfig.ID,
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
Method: r.Method,
ResponseCode: int32(rw.statusCode),
SourceIP: sourceIP,
AuthMechanism: authMechanism,
UserID: userID,
AuthSuccess: authSuccess,
}
p.requestCallback(data)
}
}
// findRoute finds the matching route for a given host and path
func (p *Proxy) findRoute(host, path string) *routeEntry {
p.mu.RLock()
defer p.mu.RUnlock()
// Strip port from host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// O(1) lookup by host
routeConfig, exists := p.routes[host]
if !exists {
return nil
}
// Build list of route entries sorted by path specificity
var entries []*routeEntry
// Create entries for each path mapping
for routePath, target := range routeConfig.PathMappings {
proxy := p.createProxy(routeConfig, target)
// ALWAYS wrap proxy with auth middleware (even if no auth configured)
// This ensures consistent auth handling and logging
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
// Log auth configuration
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
var authType string
if routeConfig.AuthConfig.BasicAuth != nil {
authType = "basic_auth"
} else if routeConfig.AuthConfig.PIN != nil {
authType = "pin"
} else if routeConfig.AuthConfig.Bearer != nil {
authType = "bearer_jwt"
}
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
"auth_type": authType,
}).Debug("Auth middleware enabled for route")
} else {
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
}).Debug("No authentication configured for route")
}
entries = append(entries, &routeEntry{
routeConfig: routeConfig,
path: routePath,
target: target,
proxy: proxy,
handler: handler,
})
}
// Sort by path specificity (longest first)
sort.Slice(entries, func(i, j int) bool {
pi, pj := entries[i].path, entries[j].path
// Empty string or "/" goes last (catch-all)
if pi == "" || pi == "/" {
return false
}
if pj == "" || pj == "/" {
return true
}
return len(pi) > len(pj)
})
// Find first matching entry
for _, entry := range entries {
if entry.path == "" || entry.path == "/" {
// Catch-all route
return entry
}
if strings.HasPrefix(path, entry.path) {
return entry
}
}
return nil
}
// createProxy creates a reverse proxy for a target with the route's connection
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
// Parse target URL
targetURL, err := url.Parse("http://" + target)
if err != nil {
log.Errorf("Failed to parse target URL %s: %v", target, err)
// Return a proxy that returns 502
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
},
}
}
// Create reverse proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Check if this is a defaultConn (for testing)
if dc, ok := routeConfig.Conn.(*defaultConn); ok {
// For defaultConn, use its dialer directly
proxy.Transport = &http.Transport{
DialContext: dc.dialer.DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
log.Infof("Using default network dialer for route %s (testing mode)", routeConfig.ID)
} else {
// Configure transport to use the provided connection (WireGuard, etc.)
proxy.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
log.Debugf("Using custom connection for route %s to %s", routeConfig.ID, address)
return routeConfig.Conn, nil
},
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
IdleConnTimeout: 0, // Keep alive indefinitely
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
log.Infof("Using custom connection for route %s", routeConfig.ID)
}
// Custom error handler
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
}
return proxy
}
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
// Check if OIDC handler is available
if p.oidcHandler == nil {
log.Error("OIDC callback received but no OIDC handler configured")
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
return
}
// Use the OIDC handler's callback method
handler := p.oidcHandler.HandleCallback()
handler(w, r)
}
// extractSourceIP extracts the source IP from the request
func extractSourceIP(r *http.Request) string {
// Try X-Forwarded-For header first
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Try X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}