mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
cleanup
This commit is contained in:
262
proxy/internal/reverseproxy/handler.go
Normal file
262
proxy/internal/reverseproxy/handler.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
)
|
||||
|
||||
// buildHandler creates the main HTTP handler with router for static endpoints
|
||||
func (p *Proxy) buildHandler() http.Handler {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Register static endpoints
|
||||
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
|
||||
|
||||
// Catch-all handler for dynamic proxy routing
|
||||
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// handleProxyRequest handles all dynamic proxy requests
|
||||
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
|
||||
routeEntry := p.findRoute(r.Host, r.URL.Path)
|
||||
if routeEntry == nil {
|
||||
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
routeEntry.handler.ServeHTTP(rw, r)
|
||||
|
||||
if p.requestCallback != nil {
|
||||
duration := time.Since(startTime)
|
||||
|
||||
host := r.Host
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// Get auth info from headers set by auth middleware
|
||||
authMechanism := r.Header.Get("X-Auth-Method")
|
||||
if authMechanism == "" {
|
||||
authMechanism = "none"
|
||||
}
|
||||
|
||||
userID := r.Header.Get("X-Auth-User-ID")
|
||||
|
||||
// Determine auth success based on status code
|
||||
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
|
||||
|
||||
// Extract source IP directly
|
||||
sourceIP := extractSourceIP(r)
|
||||
|
||||
data := RequestData{
|
||||
ServiceID: routeEntry.routeConfig.ID,
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(rw.statusCode),
|
||||
SourceIP: sourceIP,
|
||||
AuthMechanism: authMechanism,
|
||||
UserID: userID,
|
||||
AuthSuccess: authSuccess,
|
||||
}
|
||||
|
||||
p.requestCallback(data)
|
||||
}
|
||||
}
|
||||
|
||||
// findRoute finds the matching route for a given host and path
|
||||
func (p *Proxy) findRoute(host, path string) *routeEntry {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// Strip port from host
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// O(1) lookup by host
|
||||
routeConfig, exists := p.routes[host]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build list of route entries sorted by path specificity
|
||||
var entries []*routeEntry
|
||||
|
||||
// Create entries for each path mapping
|
||||
for routePath, target := range routeConfig.PathMappings {
|
||||
proxy := p.createProxy(routeConfig, target)
|
||||
|
||||
// ALWAYS wrap proxy with auth middleware (even if no auth configured)
|
||||
// This ensures consistent auth handling and logging
|
||||
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
|
||||
|
||||
// Log auth configuration
|
||||
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
|
||||
var authType string
|
||||
if routeConfig.AuthConfig.BasicAuth != nil {
|
||||
authType = "basic_auth"
|
||||
} else if routeConfig.AuthConfig.PIN != nil {
|
||||
authType = "pin"
|
||||
} else if routeConfig.AuthConfig.Bearer != nil {
|
||||
authType = "bearer_jwt"
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeConfig.ID,
|
||||
"auth_type": authType,
|
||||
}).Debug("Auth middleware enabled for route")
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeConfig.ID,
|
||||
}).Debug("No authentication configured for route")
|
||||
}
|
||||
|
||||
entries = append(entries, &routeEntry{
|
||||
routeConfig: routeConfig,
|
||||
path: routePath,
|
||||
target: target,
|
||||
proxy: proxy,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by path specificity (longest first)
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
pi, pj := entries[i].path, entries[j].path
|
||||
// Empty string or "/" goes last (catch-all)
|
||||
if pi == "" || pi == "/" {
|
||||
return false
|
||||
}
|
||||
if pj == "" || pj == "/" {
|
||||
return true
|
||||
}
|
||||
return len(pi) > len(pj)
|
||||
})
|
||||
|
||||
// Find first matching entry
|
||||
for _, entry := range entries {
|
||||
if entry.path == "" || entry.path == "/" {
|
||||
// Catch-all route
|
||||
return entry
|
||||
}
|
||||
if strings.HasPrefix(path, entry.path) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProxy creates a reverse proxy for a target with the route's connection
|
||||
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
|
||||
// Parse target URL
|
||||
targetURL, err := url.Parse("http://" + target)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse target URL %s: %v", target, err)
|
||||
// Return a proxy that returns 502
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Create reverse proxy
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
// Check if this is a defaultConn (for testing)
|
||||
if dc, ok := routeConfig.Conn.(*defaultConn); ok {
|
||||
// For defaultConn, use its dialer directly
|
||||
proxy.Transport = &http.Transport{
|
||||
DialContext: dc.dialer.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
log.Infof("Using default network dialer for route %s (testing mode)", routeConfig.ID)
|
||||
} else {
|
||||
// Configure transport to use the provided connection (WireGuard, etc.)
|
||||
proxy.Transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Using custom connection for route %s to %s", routeConfig.ID, address)
|
||||
return routeConfig.Conn, nil
|
||||
},
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
IdleConnTimeout: 0, // Keep alive indefinitely
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
log.Infof("Using custom connection for route %s", routeConfig.ID)
|
||||
}
|
||||
|
||||
// Custom error handler
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
|
||||
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if OIDC handler is available
|
||||
if p.oidcHandler == nil {
|
||||
log.Error("OIDC callback received but no OIDC handler configured")
|
||||
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Use the OIDC handler's callback method
|
||||
handler := p.oidcHandler.HandleCallback()
|
||||
handler(w, r)
|
||||
}
|
||||
|
||||
// extractSourceIP extracts the source IP from the request
|
||||
func extractSourceIP(r *http.Request) string {
|
||||
// Try X-Forwarded-For header first
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the list
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Try X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
Reference in New Issue
Block a user