Files
netbird/proxy/internal/reverseproxy/handler.go
2026-01-16 12:01:52 +01:00

225 lines
5.8 KiB
Go

package reverseproxy
import (
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"sort"
	"strings"
	"time"

	"github.com/gorilla/mux"
	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/proxy/internal/auth"
)
// buildHandler assembles the top-level HTTP handler for the proxy.
//
// The shared OIDC callback endpoint is registered first; gorilla/mux tests
// routes in the order they were added, so it is checked before the catch-all
// "/" prefix route, which hands every other request to handleProxyRequest.
func (p *Proxy) buildHandler() http.Handler {
	r := mux.NewRouter()

	// Static endpoint: OIDC callback, registered for GET.
	callback := r.HandleFunc("/auth/callback", p.handleOIDCCallback)
	callback.Methods("GET")

	// Dynamic routing by host/path for everything else.
	r.PathPrefix("/").HandlerFunc(p.handleProxyRequest)

	return r
}
// handleProxyRequest is the catch-all handler for proxied traffic.
//
// It looks up the route for the request's Host and path and responds 404 when
// none matches. Otherwise it serves the request through the route entry's
// handler. When p.requestCallback is set, the callback is invoked after the
// handler returns with a RequestData summary: service ID, host without port,
// path, duration, method, response status, source IP and auth details.
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()

	entry := p.findRoute(r.Host, r.URL.Path)
	if entry == nil {
		log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
		http.NotFound(w, r)
		return
	}

	// Wrap the writer so the response status can be reported afterwards. It
	// starts at 200; presumably responseWriter (defined elsewhere) records the
	// code passed to WriteHeader — confirm.
	rw := &responseWriter{
		ResponseWriter: w,
		statusCode:     http.StatusOK,
	}
	entry.handler.ServeHTTP(rw, r)

	if p.requestCallback == nil {
		return
	}

	duration := time.Since(startTime)

	host := r.Host
	// Drop a trailing ":port". Only a colon after the last ']' counts, so a
	// bracketed IPv6 literal without a port (e.g. "[::1]") is kept intact.
	if i := strings.LastIndexByte(host, ':'); i > strings.LastIndexByte(host, ']') {
		host = host[:i]
	}

	// TODO: extract logging data
	// NOTE(review): these headers are read from the inbound request; confirm
	// the auth middleware sets/strips them, otherwise a client can supply them.
	authMechanism := r.Header.Get("X-Auth-Method")
	if authMechanism == "" {
		authMechanism = "none"
	}

	// Any status other than 401/403 is reported as a successful auth.
	authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden

	p.requestCallback(RequestData{
		ServiceID:     entry.routeConfig.ID,
		Host:          host,
		Path:          r.URL.Path,
		DurationMs:    duration.Milliseconds(),
		Method:        r.Method,
		ResponseCode:  int32(rw.statusCode),
		SourceIP:      extractSourceIP(r),
		AuthMechanism: authMechanism,
		UserID:        r.Header.Get("X-Auth-User-ID"),
		AuthSuccess:   authSuccess,
	})
}
// findRoute returns the route entry that should serve a request for host and
// path, or nil if the host is unknown or no path mapping matches.
//
// Any ":port" suffix is removed from host before it is looked up in p.routes.
// Within the route's PathMappings, the longest mapping path that is a prefix of
// path wins. The paths "" and "/" are catch-alls, used only when no other
// mapping matches. Matching is a plain string prefix test, so "/api" also
// matches "/apiv2".
//
// Only the selected mapping gets a reverse proxy and auth handler built; the
// other mappings are not instantiated. The read lock on p.mu is held for the
// whole lookup.
//
// NOTE(review): a new proxy (with its own http.Transport) and auth wrapper are
// still created on every call, i.e. on every request. Consider building
// entries once when routes are registered and reusing them here.
func (p *Proxy) findRoute(host, path string) *routeEntry {
	p.mu.RLock()
	defer p.mu.RUnlock()

	// Drop a trailing ":port". Only a colon after the last ']' counts, so a
	// bracketed IPv6 literal without a port (e.g. "[::1]") is kept intact.
	if i := strings.LastIndexByte(host, ':'); i > strings.LastIndexByte(host, ']') {
		host = host[:i]
	}

	routeConfig, exists := p.routes[host]
	if !exists {
		return nil
	}

	// Order the mapping paths by specificity: longest first, catch-alls last.
	paths := make([]string, 0, len(routeConfig.PathMappings))
	for routePath := range routeConfig.PathMappings {
		paths = append(paths, routePath)
	}
	sort.Slice(paths, func(i, j int) bool {
		pi, pj := paths[i], paths[j]
		if pi == "" || pi == "/" {
			return false
		}
		if pj == "" || pj == "/" {
			return true
		}
		return len(pi) > len(pj)
	})

	// Pick the first (most specific) mapping that matches the request path.
	routePath, found := "", false
	for _, candidate := range paths {
		if candidate == "" || candidate == "/" || strings.HasPrefix(path, candidate) {
			routePath, found = candidate, true
			break
		}
	}
	if !found {
		return nil
	}

	target := routeConfig.PathMappings[routePath]
	proxy := p.createProxy(routeConfig, target)
	handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)

	if authConfig := routeConfig.AuthConfig; authConfig != nil && !authConfig.IsEmpty() {
		var authType string
		switch {
		case authConfig.BasicAuth != nil:
			authType = "basic_auth"
		case authConfig.PIN != nil:
			authType = "pin"
		case authConfig.Bearer != nil:
			authType = "bearer_jwt"
		}
		log.WithFields(log.Fields{
			"route_id":  routeConfig.ID,
			"auth_type": authType,
		}).Debug("Auth middleware enabled for route")
	} else {
		log.WithFields(log.Fields{
			"route_id": routeConfig.ID,
		}).Debug("No authentication configured for route")
	}

	return &routeEntry{
		routeConfig: routeConfig,
		path:        routePath,
		target:      target,
		proxy:       proxy,
		handler:     handler,
	}
}
// createProxy creates a reverse proxy for a target with the route's connection
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
targetURL, err := url.Parse("http://" + target)
if err != nil {
log.Errorf("Failed to parse target URL %s: %v", target, err)
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
},
}
}
proxy := httputil.NewSingleHostReverseProxy(targetURL)
proxy.Transport = &http.Transport{
DialContext: routeConfig.nbClient.DialContext,
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
IdleConnTimeout: 0,
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
}
return proxy
}
// handleOIDCCallback serves the shared /auth/callback endpoint for every route.
//
// It delegates to the handler returned by p.oidcHandler.HandleCallback(). If no
// OIDC handler is configured, it logs an error and replies 500.
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
	if p.oidcHandler != nil {
		p.oidcHandler.HandleCallback()(w, r)
		return
	}

	log.Error("OIDC callback received but no OIDC handler configured")
	http.Error(w, "Authentication not configured", http.StatusInternalServerError)
}
// extractSourceIP returns the client IP address to record for r.
//
// Precedence:
//  1. the first entry of X-Forwarded-For (whitespace-trimmed),
//  2. X-Real-IP (whitespace-trimmed),
//  3. the host part of r.RemoteAddr. IPv6 addresses are returned without
//     brackets ("[::1]:1234" -> "::1"). If RemoteAddr has no port, it is
//     returned unchanged.
//
// NOTE(review): both headers are taken verbatim from the incoming request, so
// a client can set them to any value unless a trusted upstream overwrites them.
func extractSourceIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		return strings.TrimSpace(first)
	}

	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return strings.TrimSpace(xri)
	}

	// SplitHostPort understands bracketed IPv6 literals, unlike a plain
	// "last colon" split.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}