mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
using go http reverseproxy with OIDC auth
This commit is contained in:
613
proxy/internal/reverseproxy/auth.go
Normal file
613
proxy/internal/reverseproxy/auth.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const (
|
||||
// Default values for authentication
|
||||
defaultSessionCookieName = "auth_session"
|
||||
defaultPINHeader = "X-PIN"
|
||||
|
||||
// OIDC state expiration time
|
||||
oidcStateExpiration = 10 * time.Minute
|
||||
|
||||
// Error messages
|
||||
errInternalServer = "Internal Server Error"
|
||||
)
|
||||
|
||||
// Global state store for OIDC flow (state -> original URL)
|
||||
var (
|
||||
oidcStateStore = &stateStore{
|
||||
states: make(map[string]*oidcState),
|
||||
}
|
||||
)
|
||||
|
||||
type stateStore struct {
|
||||
mu sync.RWMutex
|
||||
states map[string]*oidcState
|
||||
}
|
||||
|
||||
type oidcState struct {
|
||||
originalURL string
|
||||
createdAt time.Time
|
||||
routeID string
|
||||
}
|
||||
|
||||
func (s *stateStore) Store(state, originalURL, routeID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.states[state] = &oidcState{
|
||||
originalURL: originalURL,
|
||||
createdAt: time.Now(),
|
||||
routeID: routeID,
|
||||
}
|
||||
|
||||
// Clean up expired states
|
||||
cutoff := time.Now().Add(-oidcStateExpiration)
|
||||
for k, v := range s.states {
|
||||
if v.createdAt.Before(cutoff) {
|
||||
delete(s.states, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *stateStore) Get(state string) (*oidcState, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
st, ok := s.states[state]
|
||||
return st, ok
|
||||
}
|
||||
|
||||
func (s *stateStore) Delete(state string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.states, state)
|
||||
}
|
||||
|
||||
// AuthConfig holds the authentication configuration for a route
|
||||
// Only ONE auth method should be configured per route
|
||||
type AuthConfig struct {
|
||||
// HTTP Basic authentication (username/password)
|
||||
BasicAuth *BasicAuthConfig
|
||||
|
||||
// PIN authentication
|
||||
PIN *PINConfig
|
||||
|
||||
// Bearer token with JWT validation and OAuth/OIDC flow
|
||||
// When enabled, uses the global OIDCConfig from proxy Config
|
||||
Bearer *BearerConfig
|
||||
}
|
||||
|
||||
// BasicAuthConfig holds HTTP Basic authentication settings
|
||||
type BasicAuthConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// PINConfig holds PIN authentication settings
|
||||
type PINConfig struct {
|
||||
PIN string
|
||||
Header string // Header name (default: "X-PIN")
|
||||
}
|
||||
|
||||
// BearerConfig holds JWT/OAuth/OIDC bearer token authentication settings
|
||||
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
|
||||
// This just enables Bearer auth for a specific route
|
||||
type BearerConfig struct {
|
||||
// Enable bearer token authentication for this route
|
||||
// Uses the global OIDC configuration from proxy Config
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// IsEmpty returns true if no auth methods are configured
|
||||
func (c *AuthConfig) IsEmpty() bool {
|
||||
if c == nil {
|
||||
return true
|
||||
}
|
||||
return c.BasicAuth == nil && c.PIN == nil && c.Bearer == nil
|
||||
}
|
||||
|
||||
// authMiddlewareHandler is a static middleware that checks AuthConfig
|
||||
type authMiddlewareHandler struct {
|
||||
next http.Handler
|
||||
authConfig *AuthConfig
|
||||
routeID string
|
||||
rejectResponse func(w http.ResponseWriter, r *http.Request)
|
||||
oidcConfig *OIDCConfig // Global OIDC configuration from proxy
|
||||
jwtValidator *jwt.Validator // JWT validator instance (lazily initialized)
|
||||
validatorMu sync.Mutex // Mutex for thread-safe validator initialization
|
||||
}
|
||||
|
||||
func (h *authMiddlewareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// If no auth configured, allow request
|
||||
if h.authConfig.IsEmpty() {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"auth_method": "none",
|
||||
"path": r.URL.Path,
|
||||
}).Debug("No authentication configured, allowing request")
|
||||
r.Header.Set("X-Auth-Method", "none")
|
||||
h.next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
var authMethod string
|
||||
var userID string
|
||||
authenticated := false
|
||||
|
||||
// 1. Check Basic Auth
|
||||
if h.authConfig.BasicAuth != nil {
|
||||
if auth := r.Header.Get("Authorization"); auth != "" && strings.HasPrefix(auth, "Basic ") {
|
||||
encoded := strings.TrimPrefix(auth, "Basic ")
|
||||
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
credentials := string(decoded)
|
||||
parts := strings.SplitN(credentials, ":", 2)
|
||||
if len(parts) == 2 {
|
||||
username, password := parts[0], parts[1]
|
||||
if username == h.authConfig.BasicAuth.Username && password == h.authConfig.BasicAuth.Password {
|
||||
authenticated = true
|
||||
authMethod = "basic"
|
||||
userID = username
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check PIN (if not already authenticated)
|
||||
if !authenticated && h.authConfig.PIN != nil {
|
||||
headerName := h.authConfig.PIN.Header
|
||||
if headerName == "" {
|
||||
headerName = defaultPINHeader
|
||||
}
|
||||
if pin := r.Header.Get(headerName); pin != "" {
|
||||
if pin == h.authConfig.PIN.PIN {
|
||||
authenticated = true
|
||||
authMethod = "pin"
|
||||
userID = "pin_user" // PIN doesn't have a specific user ID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check Bearer Token with JWT validation (if not already authenticated)
|
||||
if !authenticated && h.authConfig.Bearer != nil && h.oidcConfig != nil {
|
||||
cookieName := h.oidcConfig.SessionCookieName
|
||||
if cookieName == "" {
|
||||
cookieName = defaultSessionCookieName
|
||||
}
|
||||
|
||||
// First, check if there's an _auth_token query parameter (from callback redirect)
|
||||
// This allows us to set the cookie for the current domain
|
||||
if authToken := r.URL.Query().Get("_auth_token"); authToken != "" {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"host": r.Host,
|
||||
}).Info("Found auth token in query parameter, setting cookie and redirecting")
|
||||
|
||||
// Validate the token before setting cookie
|
||||
if h.validateJWT(authToken) {
|
||||
// Set cookie for current domain
|
||||
cookie := &http.Cookie{
|
||||
Name: cookieName,
|
||||
Value: authToken,
|
||||
Path: "/",
|
||||
MaxAge: 3600, // 1 hour
|
||||
HttpOnly: true,
|
||||
Secure: false, // Set to false for HTTP testing, true for HTTPS in production
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
}
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
// Redirect to same URL without the token parameter
|
||||
cleanURL := *r.URL
|
||||
q := cleanURL.Query()
|
||||
q.Del("_auth_token")
|
||||
cleanURL.RawQuery = q.Encode()
|
||||
|
||||
// Build full URL with scheme and host
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
redirectURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, cleanURL.String())
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"redirect_url": redirectURL,
|
||||
}).Debug("Redirecting to clean URL after setting cookie")
|
||||
|
||||
http.Redirect(w, r, redirectURL, http.StatusFound)
|
||||
return
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
}).Warn("Invalid token in query parameter")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if we have an existing session cookie (from OIDC flow)
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"cookie_name": cookieName,
|
||||
"host": r.Host,
|
||||
"path": r.URL.Path,
|
||||
}).Debug("Checking for session cookie")
|
||||
|
||||
if cookie, err := r.Cookie(cookieName); err == nil && cookie.Value != "" {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"cookie_name": cookieName,
|
||||
}).Debug("Session cookie found, validating JWT")
|
||||
|
||||
// Validate the JWT token from the session cookie
|
||||
if h.validateJWT(cookie.Value) {
|
||||
authenticated = true
|
||||
authMethod = "bearer_session"
|
||||
userID = h.extractUserIDFromJWT(cookie.Value)
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
}).Debug("JWT validation failed for session cookie")
|
||||
}
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"error": err,
|
||||
}).Debug("No session cookie found")
|
||||
}
|
||||
|
||||
// If no session cookie or validation failed, check Authorization header
|
||||
if !authenticated {
|
||||
if auth := r.Header.Get("Authorization"); auth != "" && strings.HasPrefix(auth, "Bearer ") {
|
||||
token := strings.TrimPrefix(auth, "Bearer ")
|
||||
// Validate JWT token from Authorization header
|
||||
if h.validateJWT(token) {
|
||||
authenticated = true
|
||||
authMethod = "bearer"
|
||||
userID = h.extractUserIDFromJWT(token)
|
||||
}
|
||||
} else {
|
||||
// No bearer token and no valid session - redirect to OIDC provider
|
||||
if h.oidcConfig.ProviderURL != "" {
|
||||
// Initiate OAuth/OIDC flow
|
||||
h.redirectToOIDC(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reject if authentication failed
|
||||
if !authenticated {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"path": r.URL.Path,
|
||||
"source_ip": extractSourceIP(r),
|
||||
}).Warn("Authentication failed")
|
||||
|
||||
// Call custom reject response or use default
|
||||
if h.rejectResponse != nil {
|
||||
h.rejectResponse(w, r)
|
||||
} else {
|
||||
// Default: return 401 with WWW-Authenticate header
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="Restricted"`)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"auth_method": authMethod,
|
||||
"user_id": userID,
|
||||
"path": r.URL.Path,
|
||||
}).Debug("Authentication successful")
|
||||
|
||||
// Store auth info in headers for logging
|
||||
r.Header.Set("X-Auth-Method", authMethod)
|
||||
r.Header.Set("X-Auth-User-ID", userID)
|
||||
|
||||
// Continue to next handler
|
||||
h.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// redirectToOIDC initiates the OAuth/OIDC authentication flow
|
||||
func (h *authMiddlewareHandler) redirectToOIDC(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate random state for CSRF protection
|
||||
state, err := generateRandomString(32)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to generate OIDC state")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store state with original URL for redirect after auth
|
||||
// Include the full URL with scheme and host
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
|
||||
oidcStateStore.Store(state, originalURL, h.routeID)
|
||||
|
||||
// Default scopes if not configured
|
||||
scopes := h.oidcConfig.Scopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
|
||||
// Build authorization URL
|
||||
authURL, err := url.Parse(h.oidcConfig.ProviderURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Invalid OIDC provider URL")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Append /authorize if it doesn't exist (common OIDC endpoint)
|
||||
if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
|
||||
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
|
||||
}
|
||||
|
||||
// Build query parameters
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.oidcConfig.ClientID)
|
||||
params.Set("redirect_uri", h.oidcConfig.RedirectURL)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("scope", strings.Join(scopes, " "))
|
||||
params.Set("state", state)
|
||||
|
||||
// Add audience parameter to get an access token for the API
|
||||
// This ensures we get a proper JWT for the API audience, not just an ID token
|
||||
if len(h.oidcConfig.JWTAudience) > 0 && h.oidcConfig.JWTAudience[0] != h.oidcConfig.ClientID {
|
||||
params.Set("audience", h.oidcConfig.JWTAudience[0])
|
||||
}
|
||||
|
||||
authURL.RawQuery = params.Encode()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": h.routeID,
|
||||
"provider_url": authURL.String(),
|
||||
"redirect_url": h.oidcConfig.RedirectURL,
|
||||
"state": state,
|
||||
}).Info("Redirecting to OIDC provider for authentication")
|
||||
|
||||
// Redirect user to identity provider login page
|
||||
http.Redirect(w, r, authURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// generateRandomString generates a cryptographically secure random string
|
||||
func generateRandomString(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
|
||||
// HandleOIDCCallback handles the callback from the OIDC provider
|
||||
// This should be registered as a route handler for the callback URL
|
||||
func HandleOIDCCallback(oidcConfig *OIDCConfig) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get authorization code and state from query parameters
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
log.Error("Missing code or state in OIDC callback")
|
||||
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify state to prevent CSRF
|
||||
oidcSt, ok := oidcStateStore.Get(state)
|
||||
if !ok {
|
||||
log.Error("Invalid or expired OIDC state")
|
||||
http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete state to prevent reuse
|
||||
oidcStateStore.Delete(state)
|
||||
|
||||
// Exchange authorization code for token
|
||||
token, err := exchangeCodeForToken(code, oidcConfig)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to exchange code for token")
|
||||
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the original URL to add the token as a query parameter
|
||||
origURL, err := url.Parse(oidcSt.originalURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to parse original URL")
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add token as query parameter so the original domain can set its own cookie
|
||||
// We use a special parameter name that the auth middleware will look for
|
||||
q := origURL.Query()
|
||||
q.Set("_auth_token", token)
|
||||
origURL.RawQuery = q.Encode()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": oidcSt.routeID,
|
||||
"original_url": oidcSt.originalURL,
|
||||
"redirect_url": origURL.String(),
|
||||
"callback_host": r.Host,
|
||||
}).Info("OIDC authentication successful, redirecting with token parameter")
|
||||
|
||||
// Redirect back to original URL with token parameter
|
||||
http.Redirect(w, r, origURL.String(), http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges an authorization code for an access token
|
||||
func exchangeCodeForToken(code string, config *OIDCConfig) (string, error) {
|
||||
// Build token endpoint URL
|
||||
tokenURL, err := url.Parse(config.ProviderURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
|
||||
}
|
||||
|
||||
// Auth0 uses /oauth/token, standard OIDC uses /token
|
||||
// Check if path already contains token endpoint
|
||||
if !strings.Contains(tokenURL.Path, "/token") {
|
||||
tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
|
||||
}
|
||||
|
||||
// Build request body
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", config.RedirectURL)
|
||||
data.Set("client_id", config.ClientID)
|
||||
|
||||
// Only include client_secret if it's provided (not needed for public/SPA clients)
|
||||
if config.ClientSecret != "" {
|
||||
data.Set("client_secret", config.ClientSecret)
|
||||
}
|
||||
|
||||
// Make token exchange request
|
||||
resp, err := http.PostForm(tokenURL.String(), data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token exchange request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
return "", fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return "", fmt.Errorf("no access token in response")
|
||||
}
|
||||
|
||||
// Return the ID token if available (contains user claims), otherwise access token
|
||||
if tokenResp.IDToken != "" {
|
||||
return tokenResp.IDToken, nil
|
||||
}
|
||||
|
||||
return tokenResp.AccessToken, nil
|
||||
}
|
||||
|
||||
// getOrInitValidator lazily initializes and returns the JWT validator
|
||||
func (h *authMiddlewareHandler) getOrInitValidator() *jwt.Validator {
|
||||
h.validatorMu.Lock()
|
||||
defer h.validatorMu.Unlock()
|
||||
|
||||
if h.jwtValidator == nil {
|
||||
h.jwtValidator = jwt.NewValidator(
|
||||
h.oidcConfig.JWTIssuer,
|
||||
h.oidcConfig.JWTAudience,
|
||||
h.oidcConfig.JWTKeysLocation,
|
||||
h.oidcConfig.JWTIdpSignkeyRefreshEnabled,
|
||||
)
|
||||
}
|
||||
|
||||
return h.jwtValidator
|
||||
}
|
||||
|
||||
// validateJWT validates a JWT token using the handler's JWT validator
|
||||
func (h *authMiddlewareHandler) validateJWT(tokenString string) bool {
|
||||
if h.oidcConfig == nil || h.oidcConfig.JWTKeysLocation == "" {
|
||||
log.Error("JWT validation failed: OIDC config or JWTKeysLocation is missing")
|
||||
return false
|
||||
}
|
||||
|
||||
// Get or initialize validator
|
||||
validator := h.getOrInitValidator()
|
||||
|
||||
// Validate the token
|
||||
ctx := context.Background()
|
||||
parsedToken, err := validator.ValidateAndParse(ctx, tokenString)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("JWT validation failed")
|
||||
// Try to parse token without validation to see what's in it
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) == 3 {
|
||||
// Decode payload (middle part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err == nil {
|
||||
log.WithFields(log.Fields{
|
||||
"payload": string(payload),
|
||||
}).Debug("Token payload for debugging")
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Token is valid if parsedToken is not nil and Valid is true
|
||||
return parsedToken != nil && parsedToken.Valid
|
||||
}
|
||||
|
||||
// extractUserIDFromJWT extracts the user ID from a JWT token
|
||||
func (h *authMiddlewareHandler) extractUserIDFromJWT(tokenString string) string {
|
||||
if h.jwtValidator == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the token
|
||||
ctx := context.Background()
|
||||
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// parsedToken is already *jwtgo.Token from ValidateAndParse
|
||||
// Create extractor to get user auth info
|
||||
extractor := jwt.NewClaimsExtractor()
|
||||
userAuth, err := extractor.ToUserAuth(parsedToken)
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Failed to extract user ID from JWT")
|
||||
return ""
|
||||
}
|
||||
|
||||
return userAuth.UserId
|
||||
}
|
||||
|
||||
// wrapWithAuth wraps a handler with the static authentication middleware
|
||||
// This ALWAYS runs (even when authConfig is nil or empty)
|
||||
func wrapWithAuth(next http.Handler, authConfig *AuthConfig, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcConfig *OIDCConfig) http.Handler {
|
||||
if authConfig == nil {
|
||||
authConfig = &AuthConfig{} // Empty config = no auth
|
||||
}
|
||||
|
||||
return &authMiddlewareHandler{
|
||||
next: next,
|
||||
authConfig: authConfig,
|
||||
routeID: routeID,
|
||||
rejectResponse: rejectResponse,
|
||||
oidcConfig: oidcConfig,
|
||||
}
|
||||
}
|
||||
@@ -1,626 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp/reverseproxy"
|
||||
"github.com/caddyserver/caddy/v2/modules/logging"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// CaddyProxy wraps Caddy's reverse proxy functionality
|
||||
type CaddyProxy struct {
|
||||
config Config
|
||||
mu sync.RWMutex
|
||||
isRunning bool
|
||||
routes map[string]*RouteConfig // key is route ID
|
||||
requestCallback RequestDataCallback
|
||||
// customHandlers stores handlers with custom transports that can't be JSON-serialized
|
||||
// key is "routeID:path" to uniquely identify each handler
|
||||
customHandlers map[string]*reverseproxy.Handler
|
||||
}
|
||||
|
||||
// Config holds the reverse proxy configuration
|
||||
type Config struct {
|
||||
// ListenAddress is the address to listen on
|
||||
ListenAddress string
|
||||
|
||||
// EnableHTTPS enables automatic HTTPS with Let's Encrypt
|
||||
EnableHTTPS bool
|
||||
|
||||
// TLSEmail is the email for Let's Encrypt registration
|
||||
TLSEmail string
|
||||
|
||||
// RequestDataCallback is called for each proxied request with metrics
|
||||
RequestDataCallback RequestDataCallback
|
||||
}
|
||||
|
||||
// RouteConfig defines a routing configuration
|
||||
type RouteConfig struct {
|
||||
// ID is a unique identifier for this route
|
||||
ID string
|
||||
|
||||
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
|
||||
Domain string
|
||||
|
||||
// PathMappings defines paths that should be forwarded to specific ports
|
||||
// Key is the path prefix (e.g., "/", "/api", "/admin")
|
||||
// Value is the target IP:port (e.g., "192.168.1.100:3000")
|
||||
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
|
||||
PathMappings map[string]string
|
||||
|
||||
// Conn is an optional existing network connection to use for this route
|
||||
// This allows routing through specific tunnels (e.g., WireGuard) per route
|
||||
// If set, this connection will be reused for all requests to this route
|
||||
Conn net.Conn
|
||||
|
||||
// CustomDialer is an optional custom dialer for this specific route
|
||||
// This is used if Conn is not set. It allows using different network connections per route
|
||||
CustomDialer func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// New creates a new Caddy-based reverse proxy
|
||||
func New(config Config) (*CaddyProxy, error) {
|
||||
// Default to port 443 if not specified
|
||||
if config.ListenAddress == "" {
|
||||
config.ListenAddress = ":443"
|
||||
}
|
||||
|
||||
cp := &CaddyProxy{
|
||||
config: config,
|
||||
isRunning: false,
|
||||
routes: make(map[string]*RouteConfig),
|
||||
requestCallback: config.RequestDataCallback,
|
||||
customHandlers: make(map[string]*reverseproxy.Handler),
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
// Start starts the Caddy reverse proxy server
|
||||
func (cp *CaddyProxy) Start() error {
|
||||
cp.mu.Lock()
|
||||
if cp.isRunning {
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("reverse proxy already running")
|
||||
}
|
||||
cp.isRunning = true
|
||||
cp.mu.Unlock()
|
||||
|
||||
// Build Caddy configuration
|
||||
cfg, err := cp.buildCaddyConfig()
|
||||
if err != nil {
|
||||
cp.mu.Lock()
|
||||
cp.isRunning = false
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("failed to build Caddy config: %w", err)
|
||||
}
|
||||
|
||||
// Run Caddy with the configuration
|
||||
err = caddy.Run(cfg)
|
||||
if err != nil {
|
||||
cp.mu.Lock()
|
||||
cp.isRunning = false
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("failed to run Caddy: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Caddy reverse proxy started on %s", cp.config.ListenAddress)
|
||||
log.Infof("Configured %d route(s)", len(cp.routes))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the Caddy reverse proxy
|
||||
func (cp *CaddyProxy) Stop(ctx context.Context) error {
|
||||
cp.mu.Lock()
|
||||
if !cp.isRunning {
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("reverse proxy not running")
|
||||
}
|
||||
cp.mu.Unlock()
|
||||
|
||||
log.Info("Stopping Caddy reverse proxy...")
|
||||
|
||||
// Stop Caddy
|
||||
if err := caddy.Stop(); err != nil {
|
||||
return fmt.Errorf("failed to stop Caddy: %w", err)
|
||||
}
|
||||
|
||||
cp.mu.Lock()
|
||||
cp.isRunning = false
|
||||
cp.mu.Unlock()
|
||||
|
||||
log.Info("Caddy reverse proxy stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildCaddyConfig builds the Caddy configuration
|
||||
func (cp *CaddyProxy) buildCaddyConfig() (*caddy.Config, error) {
|
||||
cp.mu.RLock()
|
||||
defer cp.mu.RUnlock()
|
||||
|
||||
if len(cp.routes) == 0 {
|
||||
// Create a default empty server that returns 404
|
||||
httpServer := &caddyhttp.Server{
|
||||
Listen: []string{cp.config.ListenAddress},
|
||||
Routes: caddyhttp.RouteList{},
|
||||
}
|
||||
|
||||
httpApp := &caddyhttp.App{
|
||||
Servers: map[string]*caddyhttp.Server{
|
||||
"proxy": httpServer,
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &caddy.Config{
|
||||
Admin: &caddy.AdminConfig{
|
||||
Disabled: true,
|
||||
},
|
||||
AppsRaw: caddy.ModuleMap{
|
||||
"http": caddyconfig.JSON(httpApp, nil),
|
||||
},
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Build routes grouped by domain
|
||||
domainRoutes := make(map[string][]caddyhttp.Route)
|
||||
// Track unique service IDs for logger configuration
|
||||
serviceIDs := make(map[string]bool)
|
||||
|
||||
for _, routeConfig := range cp.routes {
|
||||
domain := routeConfig.Domain
|
||||
if domain == "" {
|
||||
domain = "*" // wildcard for all domains
|
||||
}
|
||||
|
||||
// Register callback for this service ID
|
||||
if cp.requestCallback != nil {
|
||||
RegisterCallback(routeConfig.ID, cp.requestCallback)
|
||||
serviceIDs[routeConfig.ID] = true
|
||||
}
|
||||
|
||||
// Sort path mappings by path length (longest first) for proper matching
|
||||
// This ensures more specific paths match before catch-all paths
|
||||
paths := make([]string, 0, len(routeConfig.PathMappings))
|
||||
for path := range routeConfig.PathMappings {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Slice(paths, func(i, j int) bool {
|
||||
// Sort by length descending, but put empty string last (catch-all)
|
||||
if paths[i] == "" || paths[i] == "/" {
|
||||
return false
|
||||
}
|
||||
if paths[j] == "" || paths[j] == "/" {
|
||||
return true
|
||||
}
|
||||
return len(paths[i]) > len(paths[j])
|
||||
})
|
||||
|
||||
// Create routes for each path mapping
|
||||
for _, path := range paths {
|
||||
target := routeConfig.PathMappings[path]
|
||||
route := cp.createRoute(routeConfig, path, target)
|
||||
domainRoutes[domain] = append(domainRoutes[domain], route)
|
||||
}
|
||||
}
|
||||
|
||||
// Build Caddy routes
|
||||
var caddyRoutes caddyhttp.RouteList
|
||||
for domain, routes := range domainRoutes {
|
||||
if domain != "*" {
|
||||
// Add host matcher for specific domains
|
||||
for i := range routes {
|
||||
routes[i].MatcherSetsRaw = []caddy.ModuleMap{
|
||||
{
|
||||
"host": caddyconfig.JSON(caddyhttp.MatchHost{domain}, nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
caddyRoutes = append(caddyRoutes, routes...)
|
||||
}
|
||||
|
||||
// Create HTTP server with access logging if callback is set
|
||||
httpServer := &caddyhttp.Server{
|
||||
Listen: []string{cp.config.ListenAddress},
|
||||
Routes: caddyRoutes,
|
||||
}
|
||||
|
||||
// Configure server logging if callback is set
|
||||
if cp.requestCallback != nil {
|
||||
httpServer.Logs = &caddyhttp.ServerLogConfig{
|
||||
// Use our custom logger for access logs
|
||||
LoggerNames: map[string]caddyhttp.StringArray{
|
||||
"http.log.access": {"http_access"},
|
||||
},
|
||||
// Disable default access logging (only use custom logger)
|
||||
ShouldLogCredentials: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Disable automatic HTTPS if not enabled
|
||||
if !cp.config.EnableHTTPS {
|
||||
// Explicitly disable automatic HTTPS for the server
|
||||
httpServer.AutoHTTPS = &caddyhttp.AutoHTTPSConfig{
|
||||
Disabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Build HTTP app
|
||||
httpApp := &caddyhttp.App{
|
||||
Servers: map[string]*caddyhttp.Server{
|
||||
"proxy": httpServer,
|
||||
},
|
||||
}
|
||||
|
||||
// Provision the HTTP app to set up handlers from JSON
|
||||
ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
|
||||
defer cancel()
|
||||
|
||||
if err := httpApp.Provision(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to provision HTTP app: %w", err)
|
||||
}
|
||||
|
||||
// After provisioning, inject custom transports into handlers
|
||||
// This is done post-provisioning so the Transport field is preserved
|
||||
if err := cp.injectCustomTransports(httpApp); err != nil {
|
||||
return nil, fmt.Errorf("failed to inject custom transports: %w", err)
|
||||
}
|
||||
|
||||
// Create Caddy config with the provisioned app
|
||||
// IMPORTANT: We pass the already-provisioned app, not JSON
|
||||
// This preserves the Transport fields we set
|
||||
cfg := &caddy.Config{
|
||||
Admin: &caddy.AdminConfig{
|
||||
Disabled: true,
|
||||
},
|
||||
// Apps field takes already-provisioned apps
|
||||
Apps: map[string]caddy.App{
|
||||
"http": httpApp,
|
||||
},
|
||||
}
|
||||
|
||||
// Configure logging if callback is set
|
||||
if cp.requestCallback != nil {
|
||||
// Register the callback for the proxy service ID
|
||||
RegisterCallback("proxy", cp.requestCallback)
|
||||
|
||||
// Build logging config with proper module names
|
||||
cfg.Logging = &caddy.Logging{
|
||||
Logs: map[string]*caddy.CustomLog{
|
||||
"http_access": {
|
||||
BaseLog: caddy.BaseLog{
|
||||
WriterRaw: caddyconfig.JSONModuleObject(&CallbackWriter{ServiceID: "proxy"}, "output", "callback", nil),
|
||||
EncoderRaw: caddyconfig.JSONModuleObject(&logging.JSONEncoder{}, "format", "json", nil),
|
||||
Level: "INFO",
|
||||
},
|
||||
Include: []string{"http.log.access"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
log.Infof("Configured custom logging with callback writer for service: proxy")
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// createRoute creates a Caddy route for a path and target with service ID tracking
|
||||
func (cp *CaddyProxy) createRoute(routeConfig *RouteConfig, path, target string) caddyhttp.Route {
|
||||
// Check if this route needs a custom transport
|
||||
hasCustomTransport := routeConfig.Conn != nil || routeConfig.CustomDialer != nil
|
||||
|
||||
if hasCustomTransport {
|
||||
// For routes with custom transports, store them separately
|
||||
// and configure the upstream to use a special dial address that we'll intercept
|
||||
handlerKey := fmt.Sprintf("%s:%s", routeConfig.ID, path)
|
||||
|
||||
// Create upstream with custom dial configuration
|
||||
upstream := &reverseproxy.Upstream{
|
||||
Dial: target,
|
||||
}
|
||||
|
||||
// Create the reverse proxy handler with custom transport
|
||||
handler := &reverseproxy.Handler{
|
||||
Upstreams: reverseproxy.UpstreamPool{upstream},
|
||||
}
|
||||
|
||||
// Configure the custom transport
|
||||
if routeConfig.Conn != nil {
|
||||
// Use the provided connection directly
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Reusing existing connection for route %s to %s", routeConfig.ID, address)
|
||||
return routeConfig.Conn, nil
|
||||
},
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
IdleConnTimeout: 0,
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
handler.Transport = transport
|
||||
log.Infof("Configured net.Conn transport for route %s (path: %s)", routeConfig.ID, path)
|
||||
} else if routeConfig.CustomDialer != nil {
|
||||
// Use the custom dialer function
|
||||
transport := &http.Transport{
|
||||
DialContext: routeConfig.CustomDialer,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
handler.Transport = transport
|
||||
log.Infof("Configured custom dialer transport for route %s (path: %s)", routeConfig.ID, path)
|
||||
}
|
||||
|
||||
// Store the handler for later injection
|
||||
cp.customHandlers[handlerKey] = handler
|
||||
|
||||
// Create route using HandlersRaw with a placeholder that will be replaced
|
||||
// We'll use JSON serialization here, but inject the real handler after Caddy loads
|
||||
route := caddyhttp.Route{
|
||||
HandlersRaw: []json.RawMessage{
|
||||
caddyconfig.JSONModuleObject(handler, "handler", "reverse_proxy", nil),
|
||||
},
|
||||
}
|
||||
|
||||
if path != "" {
|
||||
route.MatcherSetsRaw = []caddy.ModuleMap{
|
||||
{
|
||||
"path": caddyconfig.JSON(caddyhttp.MatchPath{path + "*"}, nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// Standard route without custom transport
|
||||
upstream := &reverseproxy.Upstream{
|
||||
Dial: target,
|
||||
}
|
||||
|
||||
handler := &reverseproxy.Handler{
|
||||
Upstreams: reverseproxy.UpstreamPool{upstream},
|
||||
}
|
||||
|
||||
route := caddyhttp.Route{
|
||||
HandlersRaw: []json.RawMessage{
|
||||
caddyconfig.JSONModuleObject(handler, "handler", "reverse_proxy", nil),
|
||||
},
|
||||
}
|
||||
|
||||
if path != "" {
|
||||
route.MatcherSetsRaw = []caddy.ModuleMap{
|
||||
{
|
||||
"path": caddyconfig.JSON(caddyhttp.MatchPath{path + "*"}, nil),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return route
|
||||
}
|
||||
|
||||
// IsRunning returns whether the proxy is running
|
||||
func (cp *CaddyProxy) IsRunning() bool {
|
||||
cp.mu.RLock()
|
||||
defer cp.mu.RUnlock()
|
||||
return cp.isRunning
|
||||
}
|
||||
|
||||
// GetConfig returns the proxy configuration
|
||||
func (cp *CaddyProxy) GetConfig() Config {
|
||||
return cp.config
|
||||
}
|
||||
|
||||
// AddRoute adds a new route configuration to the proxy
|
||||
// If the proxy is running, it will reload the configuration
|
||||
func (cp *CaddyProxy) AddRoute(route *RouteConfig) error {
|
||||
if route == nil {
|
||||
return fmt.Errorf("route cannot be nil")
|
||||
}
|
||||
if route.ID == "" {
|
||||
return fmt.Errorf("route ID is required")
|
||||
}
|
||||
if len(route.PathMappings) == 0 {
|
||||
return fmt.Errorf("route must have at least one path mapping")
|
||||
}
|
||||
|
||||
cp.mu.Lock()
|
||||
// Check if route already exists
|
||||
if _, exists := cp.routes[route.ID]; exists {
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("route with ID %s already exists", route.ID)
|
||||
}
|
||||
|
||||
// Add new route
|
||||
cp.routes[route.ID] = route
|
||||
isRunning := cp.isRunning
|
||||
cp.mu.Unlock()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": route.ID,
|
||||
"domain": route.Domain,
|
||||
"paths": len(route.PathMappings),
|
||||
}).Info("Added route")
|
||||
|
||||
// Reload configuration if proxy is running
|
||||
if isRunning {
|
||||
if err := cp.reloadConfig(); err != nil {
|
||||
// Rollback: remove the route
|
||||
cp.mu.Lock()
|
||||
delete(cp.routes, route.ID)
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("failed to reload config after adding route: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route from the proxy
|
||||
// If the proxy is running, it will reload the configuration
|
||||
func (cp *CaddyProxy) RemoveRoute(routeID string) error {
|
||||
cp.mu.Lock()
|
||||
// Check if route exists
|
||||
route, exists := cp.routes[routeID]
|
||||
if !exists {
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("route %s not found", routeID)
|
||||
}
|
||||
|
||||
// Remove route
|
||||
delete(cp.routes, routeID)
|
||||
isRunning := cp.isRunning
|
||||
cp.mu.Unlock()
|
||||
|
||||
log.Infof("Removed route: %s", routeID)
|
||||
|
||||
// Reload configuration if proxy is running
|
||||
if isRunning {
|
||||
if err := cp.reloadConfig(); err != nil {
|
||||
// Rollback: add the route back
|
||||
cp.mu.Lock()
|
||||
cp.routes[routeID] = route
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("failed to reload config after removing route: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoute updates an existing route configuration
|
||||
// If the proxy is running, it will reload the configuration
|
||||
func (cp *CaddyProxy) UpdateRoute(route *RouteConfig) error {
|
||||
if route == nil {
|
||||
return fmt.Errorf("route cannot be nil")
|
||||
}
|
||||
if route.ID == "" {
|
||||
return fmt.Errorf("route ID is required")
|
||||
}
|
||||
|
||||
cp.mu.Lock()
|
||||
// Check if route exists
|
||||
oldRoute, exists := cp.routes[route.ID]
|
||||
if !exists {
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("route %s not found", route.ID)
|
||||
}
|
||||
|
||||
// Update route
|
||||
cp.routes[route.ID] = route
|
||||
isRunning := cp.isRunning
|
||||
cp.mu.Unlock()
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": route.ID,
|
||||
"domain": route.Domain,
|
||||
"paths": len(route.PathMappings),
|
||||
}).Info("Updated route")
|
||||
|
||||
// Reload configuration if proxy is running
|
||||
if isRunning {
|
||||
if err := cp.reloadConfig(); err != nil {
|
||||
// Rollback: restore old route
|
||||
cp.mu.Lock()
|
||||
cp.routes[route.ID] = oldRoute
|
||||
cp.mu.Unlock()
|
||||
return fmt.Errorf("failed to reload config after updating route: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of all configured route IDs
|
||||
func (cp *CaddyProxy) ListRoutes() []string {
|
||||
cp.mu.RLock()
|
||||
defer cp.mu.RUnlock()
|
||||
|
||||
routes := make([]string, 0, len(cp.routes))
|
||||
for id := range cp.routes {
|
||||
routes = append(routes, id)
|
||||
}
|
||||
return routes
|
||||
}
|
||||
|
||||
// GetRoute returns a route configuration by ID
|
||||
func (cp *CaddyProxy) GetRoute(routeID string) (*RouteConfig, error) {
|
||||
cp.mu.RLock()
|
||||
defer cp.mu.RUnlock()
|
||||
|
||||
route, exists := cp.routes[routeID]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("route %s not found", routeID)
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// injectCustomTransports injects custom transports into provisioned handlers
|
||||
// This must be called after httpApp.Provision() but before passing to Caddy.Run()
|
||||
func (cp *CaddyProxy) injectCustomTransports(httpApp *caddyhttp.App) error {
|
||||
// Iterate through all servers
|
||||
for serverName, server := range httpApp.Servers {
|
||||
log.Debugf("Injecting custom transports for server: %s", serverName)
|
||||
|
||||
// Iterate through all routes
|
||||
for routeIdx, route := range server.Routes {
|
||||
// Iterate through all handlers in the route
|
||||
for handlerIdx, handler := range route.Handlers {
|
||||
// Check if this is a reverse proxy handler
|
||||
if rpHandler, ok := handler.(*reverseproxy.Handler); ok {
|
||||
// Try to find a matching custom handler for this route
|
||||
// We need to match by handler configuration since we don't have route metadata here
|
||||
for handlerKey, customHandler := range cp.customHandlers {
|
||||
// Check if the upstream configuration matches
|
||||
if len(rpHandler.Upstreams) > 0 && len(customHandler.Upstreams) > 0 {
|
||||
if rpHandler.Upstreams[0].Dial == customHandler.Upstreams[0].Dial {
|
||||
// Match found! Inject the custom transport
|
||||
rpHandler.Transport = customHandler.Transport
|
||||
log.Infof("Injected custom transport for route %d, handler %d (key: %s)", routeIdx, handlerIdx, handlerKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reloadConfig rebuilds and reloads the Caddy configuration
|
||||
// Must be called without holding the lock
|
||||
func (cp *CaddyProxy) reloadConfig() error {
|
||||
log.Info("Reloading Caddy configuration...")
|
||||
|
||||
cfg, err := cp.buildCaddyConfig()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build config: %w", err)
|
||||
}
|
||||
|
||||
if err := caddy.Run(cfg); err != nil {
|
||||
return fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
log.Info("Caddy configuration reloaded successfully")
|
||||
return nil
|
||||
}
|
||||
@@ -1,225 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/caddyserver/caddy/v2"
|
||||
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
// Global map to store callbacks per service ID
|
||||
callbackRegistry = make(map[string]RequestDataCallback)
|
||||
callbackMu sync.RWMutex
|
||||
)
|
||||
|
||||
// RegisterCallback registers a callback for a specific service ID
|
||||
func RegisterCallback(serviceID string, callback RequestDataCallback) {
|
||||
callbackMu.Lock()
|
||||
defer callbackMu.Unlock()
|
||||
callbackRegistry[serviceID] = callback
|
||||
}
|
||||
|
||||
// UnregisterCallback removes a callback for a specific service ID
|
||||
func UnregisterCallback(serviceID string) {
|
||||
callbackMu.Lock()
|
||||
defer callbackMu.Unlock()
|
||||
delete(callbackRegistry, serviceID)
|
||||
}
|
||||
|
||||
// getCallback retrieves the callback for a service ID
|
||||
func getCallback(serviceID string) RequestDataCallback {
|
||||
callbackMu.RLock()
|
||||
defer callbackMu.RUnlock()
|
||||
return callbackRegistry[serviceID]
|
||||
}
|
||||
|
||||
func init() {
|
||||
caddy.RegisterModule(CallbackWriter{})
|
||||
}
|
||||
|
||||
// CallbackWriter is a Caddy log writer module that sends request data via callback
|
||||
type CallbackWriter struct {
|
||||
ServiceID string `json:"service_id,omitempty"`
|
||||
}
|
||||
|
||||
// CaddyModule returns the Caddy module information
|
||||
func (CallbackWriter) CaddyModule() caddy.ModuleInfo {
|
||||
return caddy.ModuleInfo{
|
||||
ID: "caddy.logging.writers.callback",
|
||||
New: func() caddy.Module { return new(CallbackWriter) },
|
||||
}
|
||||
}
|
||||
|
||||
// Provision sets up the callback writer
|
||||
func (cw *CallbackWriter) Provision(ctx caddy.Context) error {
|
||||
log.Infof("CallbackWriter.Provision called for service_id: %s", cw.ServiceID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of the writer
|
||||
func (cw *CallbackWriter) String() string {
|
||||
return fmt.Sprintf("callback writer for service %s", cw.ServiceID)
|
||||
}
|
||||
|
||||
// WriterKey returns a unique key for this writer configuration
|
||||
func (cw *CallbackWriter) WriterKey() string {
|
||||
return "callback_" + cw.ServiceID
|
||||
}
|
||||
|
||||
// OpenWriter opens the writer
|
||||
func (cw *CallbackWriter) OpenWriter() (io.WriteCloser, error) {
|
||||
log.Infof("CallbackWriter.OpenWriter called for service_id: %s", cw.ServiceID)
|
||||
writer := &LogWriter{
|
||||
serviceID: cw.ServiceID,
|
||||
}
|
||||
log.Infof("Created LogWriter instance: %p for service_id: %s", writer, cw.ServiceID)
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
// UnmarshalCaddyfile implements caddyfile.Unmarshaler
|
||||
func (cw *CallbackWriter) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
for d.Next() {
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
cw.ServiceID = d.Val()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure CallbackWriter implements the required interfaces
|
||||
var (
|
||||
_ caddy.Provisioner = (*CallbackWriter)(nil)
|
||||
_ caddy.WriterOpener = (*CallbackWriter)(nil)
|
||||
_ caddyfile.Unmarshaler = (*CallbackWriter)(nil)
|
||||
)
|
||||
|
||||
// LogWriter is a custom io.Writer that parses Caddy's structured JSON logs
|
||||
// and extracts request metrics to send via callback
|
||||
type LogWriter struct {
|
||||
serviceID string
|
||||
}
|
||||
|
||||
// NewLogWriter creates a new log writer with the given service ID
|
||||
func NewLogWriter(serviceID string) *LogWriter {
|
||||
return &LogWriter{
|
||||
serviceID: serviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer
|
||||
func (lw *LogWriter) Write(p []byte) (n int, err error) {
|
||||
// DEBUG: Log that we received data
|
||||
log.Infof("LogWriter.Write called with %d bytes for service_id: %s", len(p), lw.serviceID)
|
||||
log.Debugf("LogWriter content: %s", string(p))
|
||||
|
||||
// Caddy writes one JSON object per line
|
||||
// Parse the JSON to extract request metrics
|
||||
var logEntry map[string]interface{}
|
||||
if err := json.Unmarshal(p, &logEntry); err != nil {
|
||||
// Not JSON or malformed, skip
|
||||
log.Debugf("Failed to unmarshal JSON: %v", err)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Caddy access logs have a nested "request" object
|
||||
// Check if this is an access log entry by looking for "request" field
|
||||
requestObj, hasRequest := logEntry["request"]
|
||||
if !hasRequest {
|
||||
log.Debugf("Not an access log entry (no 'request' field)")
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
request, ok := requestObj.(map[string]interface{})
|
||||
if !ok {
|
||||
log.Debugf("'request' field is not a map")
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Extract fields
|
||||
data := &RequestData{
|
||||
ServiceID: lw.serviceID,
|
||||
}
|
||||
|
||||
// Extract method from request object
|
||||
if method, ok := request["method"].(string); ok {
|
||||
data.Method = method
|
||||
}
|
||||
|
||||
// Extract host from request object and strip port
|
||||
if host, ok := request["host"].(string); ok {
|
||||
// Strip port from host (e.g., "test.netbird.io:54321" -> "test.netbird.io")
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
data.Host = host[:idx]
|
||||
} else {
|
||||
data.Host = host
|
||||
}
|
||||
}
|
||||
|
||||
// Extract path (uri field) from request object
|
||||
if uri, ok := request["uri"].(string); ok {
|
||||
data.Path = uri
|
||||
}
|
||||
|
||||
// Extract status code from top-level
|
||||
if status, ok := logEntry["status"].(float64); ok {
|
||||
data.ResponseCode = int32(status)
|
||||
}
|
||||
|
||||
// Extract duration (in seconds, convert to milliseconds) from top-level
|
||||
if duration, ok := logEntry["duration"].(float64); ok {
|
||||
data.DurationMs = int64(duration * 1000)
|
||||
}
|
||||
|
||||
// Extract source IP from request object - try multiple fields
|
||||
if clientIP, ok := request["client_ip"].(string); ok {
|
||||
data.SourceIP = clientIP
|
||||
} else if remoteIP, ok := request["remote_ip"].(string); ok {
|
||||
data.SourceIP = remoteIP
|
||||
} else if remoteAddr, ok := request["remote_addr"].(string); ok {
|
||||
// remote_addr is in "IP:port" format
|
||||
if idx := strings.LastIndex(remoteAddr, ":"); idx != -1 {
|
||||
data.SourceIP = remoteAddr[:idx]
|
||||
} else {
|
||||
data.SourceIP = remoteAddr
|
||||
}
|
||||
}
|
||||
|
||||
// Call callback if set and we have valid data
|
||||
callback := getCallback(lw.serviceID)
|
||||
if callback != nil && data.Method != "" {
|
||||
log.Infof("Calling callback for request: %s %s", data.Method, data.Path)
|
||||
go func() {
|
||||
// Run in goroutine to avoid blocking log writes
|
||||
callback(data)
|
||||
}()
|
||||
} else {
|
||||
log.Warnf("No callback registered for service_id: %s", lw.serviceID)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": data.ServiceID,
|
||||
"method": data.Method,
|
||||
"host": data.Host,
|
||||
"path": data.Path,
|
||||
"status": data.ResponseCode,
|
||||
"duration_ms": data.DurationMs,
|
||||
"source_ip": data.SourceIP,
|
||||
}).Info("Request logged via callback writer")
|
||||
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Close implements io.Closer (no-op for our use case)
|
||||
func (lw *LogWriter) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure LogWriter implements io.WriteCloser
|
||||
var _ io.WriteCloser = (*LogWriter)(nil)
|
||||
@@ -1,251 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogWriter_Write(t *testing.T) {
|
||||
// Create a channel to receive callback data
|
||||
callbackChan := make(chan *RequestData, 1)
|
||||
var callbackMu sync.Mutex
|
||||
var callbackCalled bool
|
||||
|
||||
// Register a test callback
|
||||
testServiceID := "test-service"
|
||||
RegisterCallback(testServiceID, func(data *RequestData) {
|
||||
callbackMu.Lock()
|
||||
callbackCalled = true
|
||||
callbackMu.Unlock()
|
||||
callbackChan <- data
|
||||
})
|
||||
defer UnregisterCallback(testServiceID)
|
||||
|
||||
// Create a log writer
|
||||
writer := NewLogWriter(testServiceID)
|
||||
|
||||
// Create a sample Caddy access log entry (matching the structure from your logs)
|
||||
logEntry := map[string]interface{}{
|
||||
"level": "info",
|
||||
"ts": 1768352053.7900746,
|
||||
"logger": "http.log.access",
|
||||
"msg": "handled request",
|
||||
"request": map[string]interface{}{
|
||||
"remote_ip": "::1",
|
||||
"remote_port": "51972",
|
||||
"client_ip": "::1",
|
||||
"proto": "HTTP/1.1",
|
||||
"method": "GET",
|
||||
"host": "test.netbird.io:54321",
|
||||
"uri": "/test/path",
|
||||
},
|
||||
"bytes_read": 0,
|
||||
"user_id": "",
|
||||
"duration": 0.004779453,
|
||||
"size": 615,
|
||||
"status": 200,
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
logJSON, err := json.Marshal(logEntry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal log entry: %v", err)
|
||||
}
|
||||
|
||||
// Write to the log writer
|
||||
n, err := writer.Write(logJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
if n != len(logJSON) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(logJSON), n)
|
||||
}
|
||||
|
||||
// Wait for callback to be called (with timeout)
|
||||
select {
|
||||
case data := <-callbackChan:
|
||||
// Verify the extracted data
|
||||
if data.ServiceID != testServiceID {
|
||||
t.Errorf("Expected service_id %s, got %s", testServiceID, data.ServiceID)
|
||||
}
|
||||
if data.Method != "GET" {
|
||||
t.Errorf("Expected method GET, got %s", data.Method)
|
||||
}
|
||||
if data.Host != "test.netbird.io" {
|
||||
t.Errorf("Expected host test.netbird.io, got %s", data.Host)
|
||||
}
|
||||
if data.Path != "/test/path" {
|
||||
t.Errorf("Expected path /test/path, got %s", data.Path)
|
||||
}
|
||||
if data.ResponseCode != 200 {
|
||||
t.Errorf("Expected status 200, got %d", data.ResponseCode)
|
||||
}
|
||||
if data.SourceIP != "::1" {
|
||||
t.Errorf("Expected source_ip ::1, got %s", data.SourceIP)
|
||||
}
|
||||
// Duration should be ~4.78ms (0.004779453 * 1000)
|
||||
if data.DurationMs < 4 || data.DurationMs > 5 {
|
||||
t.Errorf("Expected duration ~4-5ms, got %dms", data.DurationMs)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("Callback was not called within timeout")
|
||||
}
|
||||
|
||||
// Verify callback was called
|
||||
callbackMu.Lock()
|
||||
defer callbackMu.Unlock()
|
||||
if !callbackCalled {
|
||||
t.Error("Callback was never called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_NonAccessLog(t *testing.T) {
|
||||
// Create a channel to receive callback data
|
||||
callbackChan := make(chan *RequestData, 1)
|
||||
|
||||
// Register a test callback
|
||||
testServiceID := "test-service-2"
|
||||
RegisterCallback(testServiceID, func(data *RequestData) {
|
||||
callbackChan <- data
|
||||
})
|
||||
defer UnregisterCallback(testServiceID)
|
||||
|
||||
// Create a log writer
|
||||
writer := NewLogWriter(testServiceID)
|
||||
|
||||
// Create a non-access log entry (e.g., a TLS log)
|
||||
logEntry := map[string]interface{}{
|
||||
"level": "info",
|
||||
"ts": 1768352032.12347,
|
||||
"logger": "tls",
|
||||
"msg": "storage cleaning happened too recently",
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
logJSON, err := json.Marshal(logEntry)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal log entry: %v", err)
|
||||
}
|
||||
|
||||
// Write to the log writer
|
||||
n, err := writer.Write(logJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
if n != len(logJSON) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(logJSON), n)
|
||||
}
|
||||
|
||||
// Callback should NOT be called for non-access logs
|
||||
select {
|
||||
case data := <-callbackChan:
|
||||
t.Errorf("Callback should not be called for non-access log, but got: %+v", data)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
// Expected - callback not called
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Write_MalformedJSON(t *testing.T) {
|
||||
// Create a log writer
|
||||
writer := NewLogWriter("test-service-3")
|
||||
|
||||
// Write malformed JSON
|
||||
malformedJSON := []byte("{this is not valid json")
|
||||
|
||||
// Should not fail, just skip the entry
|
||||
n, err := writer.Write(malformedJSON)
|
||||
if err != nil {
|
||||
t.Fatalf("Write should not fail on malformed JSON: %v", err)
|
||||
}
|
||||
|
||||
if n != len(malformedJSON) {
|
||||
t.Errorf("Expected to write %d bytes, wrote %d", len(malformedJSON), n)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackRegistry(t *testing.T) {
|
||||
serviceID := "test-registry"
|
||||
var called bool
|
||||
|
||||
// Test registering a callback
|
||||
callback := func(data *RequestData) {
|
||||
called = true
|
||||
}
|
||||
RegisterCallback(serviceID, callback)
|
||||
|
||||
// Test retrieving the callback
|
||||
retrievedCallback := getCallback(serviceID)
|
||||
if retrievedCallback == nil {
|
||||
t.Fatal("Expected to retrieve callback, got nil")
|
||||
}
|
||||
|
||||
// Call the retrieved callback to verify it works
|
||||
retrievedCallback(&RequestData{})
|
||||
if !called {
|
||||
t.Error("Callback was not called")
|
||||
}
|
||||
|
||||
// Test unregistering
|
||||
UnregisterCallback(serviceID)
|
||||
retrievedCallback = getCallback(serviceID)
|
||||
if retrievedCallback != nil {
|
||||
t.Error("Expected nil after unregistering, got a callback")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackWriter_Module(t *testing.T) {
|
||||
// Test that the module is properly configured
|
||||
cw := CallbackWriter{ServiceID: "test"}
|
||||
|
||||
moduleInfo := cw.CaddyModule()
|
||||
if moduleInfo.ID != "caddy.logging.writers.callback" {
|
||||
t.Errorf("Expected module ID 'caddy.logging.writers.callback', got '%s'", moduleInfo.ID)
|
||||
}
|
||||
|
||||
if moduleInfo.New == nil {
|
||||
t.Error("Expected New function to be set")
|
||||
}
|
||||
|
||||
// Test creating a new instance via the New function
|
||||
newModule := moduleInfo.New()
|
||||
if newModule == nil {
|
||||
t.Error("Expected New() to return a module instance")
|
||||
}
|
||||
|
||||
_, ok := newModule.(*CallbackWriter)
|
||||
if !ok {
|
||||
t.Error("Expected New() to return a *CallbackWriter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackWriter_WriterKey(t *testing.T) {
|
||||
cw := &CallbackWriter{ServiceID: "my-service"}
|
||||
|
||||
expectedKey := "callback_my-service"
|
||||
if cw.WriterKey() != expectedKey {
|
||||
t.Errorf("Expected writer key '%s', got '%s'", expectedKey, cw.WriterKey())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallbackWriter_String(t *testing.T) {
|
||||
cw := &CallbackWriter{ServiceID: "my-service"}
|
||||
|
||||
str := cw.String()
|
||||
if str != "callback writer for service my-service" {
|
||||
t.Errorf("Unexpected string representation: %s", str)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogWriter_Close(t *testing.T) {
|
||||
writer := NewLogWriter("test")
|
||||
|
||||
// Close should not fail
|
||||
err := writer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Close should not fail: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,7 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// RequestDataCallback is called for each request that passes through the proxy
|
||||
type RequestDataCallback func(data *RequestData)
|
||||
type RequestDataCallback func(data RequestData)
|
||||
|
||||
// RequestData contains metadata about a proxied request
|
||||
type RequestData struct {
|
||||
@@ -21,111 +12,8 @@ type RequestData struct {
|
||||
Method string
|
||||
ResponseCode int32
|
||||
SourceIP string
|
||||
}
|
||||
|
||||
// MetricsMiddleware wraps a handler to capture request metrics
|
||||
type MetricsMiddleware struct {
|
||||
Next caddyhttp.Handler
|
||||
ServiceID string
|
||||
Callback RequestDataCallback
|
||||
}
|
||||
|
||||
// ServeHTTP implements caddyhttp.MiddlewareHandler
|
||||
func (m *MetricsMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
|
||||
// Record start time
|
||||
startTime := time.Now()
|
||||
|
||||
// Wrap the response writer to capture status code
|
||||
wrappedWriter := &responseWriterWrapper{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK, // Default to 200
|
||||
}
|
||||
|
||||
// Call the next handler (Caddy's reverse proxy)
|
||||
err := next.ServeHTTP(wrappedWriter, r)
|
||||
|
||||
// Calculate duration
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// Extract source IP (handle X-Forwarded-For or direct connection)
|
||||
sourceIP := extractSourceIP(r)
|
||||
|
||||
// Create request data
|
||||
data := &RequestData{
|
||||
ServiceID: m.ServiceID,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(wrappedWriter.statusCode),
|
||||
SourceIP: sourceIP,
|
||||
}
|
||||
|
||||
// Call callback if set
|
||||
if m.Callback != nil {
|
||||
go func() {
|
||||
// Run callback in goroutine to avoid blocking response
|
||||
m.Callback(data)
|
||||
}()
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"service_id": data.ServiceID,
|
||||
"method": data.Method,
|
||||
"path": data.Path,
|
||||
"status": data.ResponseCode,
|
||||
"duration_ms": data.DurationMs,
|
||||
"source_ip": data.SourceIP,
|
||||
}).Debug("Request proxied")
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// responseWriterWrapper wraps http.ResponseWriter to capture status code
|
||||
type responseWriterWrapper struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
written bool
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code
|
||||
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
|
||||
if !w.written {
|
||||
w.statusCode = statusCode
|
||||
w.written = true
|
||||
}
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
// Write ensures we capture status if WriteHeader wasn't called explicitly
|
||||
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||
if !w.written {
|
||||
w.written = true
|
||||
// Status code defaults to 200 if not explicitly set
|
||||
}
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
// extractSourceIP extracts the real client IP from the request
|
||||
func extractSourceIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header first (if behind a proxy)
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// X-Forwarded-For can be a comma-separated list, take the first one
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Check X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
// RemoteAddr is in format "IP:port", so we need to strip the port
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
|
||||
return r.RemoteAddr
|
||||
|
||||
AuthMechanism string
|
||||
UserID string
|
||||
AuthSuccess bool
|
||||
}
|
||||
|
||||
817
proxy/internal/reverseproxy/proxy.go
Normal file
817
proxy/internal/reverseproxy/proxy.go
Normal file
@@ -0,0 +1,817 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Proxy wraps a reverse proxy with dynamic routing
|
||||
type Proxy struct {
|
||||
config Config
|
||||
mu sync.RWMutex
|
||||
routes map[string]*RouteConfig // key is host/domain (for fast O(1) lookup)
|
||||
server *http.Server
|
||||
httpServer *http.Server
|
||||
autocertManager *autocert.Manager
|
||||
isRunning bool
|
||||
requestCallback RequestDataCallback
|
||||
}
|
||||
|
||||
// Config holds the reverse proxy configuration
|
||||
type Config struct {
|
||||
// ListenAddress is the address to listen on for HTTPS (default ":443")
|
||||
ListenAddress string
|
||||
|
||||
// HTTPListenAddress is the address for HTTP (default ":80")
|
||||
// Used for ACME challenges when HTTPS is enabled, or as main listener when HTTPS is disabled
|
||||
HTTPListenAddress string
|
||||
|
||||
// EnableHTTPS enables automatic HTTPS with Let's Encrypt
|
||||
EnableHTTPS bool
|
||||
|
||||
// TLSEmail is the email for Let's Encrypt registration
|
||||
TLSEmail string
|
||||
|
||||
// CertCacheDir is the directory to cache certificates (default "./certs")
|
||||
CertCacheDir string
|
||||
|
||||
// RequestDataCallback is called for each proxied request with metrics
|
||||
RequestDataCallback RequestDataCallback
|
||||
|
||||
// OIDCConfig is the global OIDC/OAuth configuration for authentication
|
||||
// This is shared across all routes that use Bearer authentication
|
||||
// If nil, routes with Bearer auth will fail to initialize
|
||||
OIDCConfig *OIDCConfig
|
||||
}
|
||||
|
||||
// OIDCConfig holds the global OIDC/OAuth configuration
|
||||
type OIDCConfig struct {
|
||||
// OIDC Provider settings
|
||||
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"` // Identity provider URL (e.g., "https://accounts.google.com")
|
||||
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"` // OAuth client ID
|
||||
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"` // OAuth client secret (empty for public clients)
|
||||
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"` // Redirect URL after auth (e.g., "http://localhost:54321/auth/callback")
|
||||
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"` // Requested scopes (default: ["openid", "profile", "email"])
|
||||
|
||||
// JWT Validation settings
|
||||
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"` // JWKS URL for fetching public keys
|
||||
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"` // Expected issuer claim
|
||||
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"` // Expected audience claims
|
||||
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"` // Enable automatic refresh of signing keys
|
||||
|
||||
// Session settings
|
||||
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"` // Cookie name for storing session (default: "auth_session")
|
||||
}
|
||||
|
||||
// RouteConfig defines a routing configuration
|
||||
type RouteConfig struct {
|
||||
// ID is a unique identifier for this route
|
||||
ID string
|
||||
|
||||
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
|
||||
Domain string
|
||||
|
||||
// PathMappings defines paths that should be forwarded to specific ports
|
||||
// Key is the path prefix (e.g., "/", "/api", "/admin")
|
||||
// Value is the target IP:port (e.g., "192.168.1.100:3000")
|
||||
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
|
||||
PathMappings map[string]string
|
||||
|
||||
// Conn is the network connection to use for this route
|
||||
// This allows routing through specific tunnels (e.g., WireGuard) per route
|
||||
// This connection will be reused for all requests to this route
|
||||
Conn net.Conn
|
||||
|
||||
// AuthConfig is optional authentication configuration for this route
|
||||
// Configure ONE of: BasicAuth, PIN, or Bearer (JWT/OIDC)
|
||||
// If nil, requests pass through without authentication
|
||||
AuthConfig *AuthConfig
|
||||
|
||||
// AuthRejectResponse is an optional custom response for authentication failures
|
||||
// If nil, returns 401 Unauthorized with WWW-Authenticate header
|
||||
AuthRejectResponse func(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// routeEntry represents a compiled route with its proxy
|
||||
type routeEntry struct {
|
||||
routeConfig *RouteConfig
|
||||
path string
|
||||
target string
|
||||
proxy *httputil.ReverseProxy
|
||||
handler http.Handler // handler wraps proxy with middleware (auth, logging, etc.)
|
||||
}
|
||||
|
||||
// New creates a new reverse proxy
|
||||
func New(config Config) (*Proxy, error) {
|
||||
// Set defaults
|
||||
if config.ListenAddress == "" {
|
||||
config.ListenAddress = ":443"
|
||||
}
|
||||
if config.HTTPListenAddress == "" {
|
||||
config.HTTPListenAddress = ":80"
|
||||
}
|
||||
if config.CertCacheDir == "" {
|
||||
config.CertCacheDir = "./certs"
|
||||
}
|
||||
|
||||
// Validate HTTPS config
|
||||
if config.EnableHTTPS {
|
||||
if config.TLSEmail == "" {
|
||||
return nil, fmt.Errorf("TLSEmail is required when EnableHTTPS is true")
|
||||
}
|
||||
}
|
||||
|
||||
// Set default OIDC session cookie name if not provided
|
||||
if config.OIDCConfig != nil && config.OIDCConfig.SessionCookieName == "" {
|
||||
config.OIDCConfig.SessionCookieName = "auth_session"
|
||||
}
|
||||
|
||||
p := &Proxy{
|
||||
config: config,
|
||||
routes: make(map[string]*RouteConfig),
|
||||
isRunning: false,
|
||||
requestCallback: config.RequestDataCallback,
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Start starts the reverse proxy server
|
||||
func (p *Proxy) Start() error {
|
||||
p.mu.Lock()
|
||||
if p.isRunning {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("reverse proxy already running")
|
||||
}
|
||||
p.isRunning = true
|
||||
p.mu.Unlock()
|
||||
|
||||
// Build the main HTTP handler
|
||||
handler := p.buildHandler()
|
||||
|
||||
if p.config.EnableHTTPS {
|
||||
// Setup autocert manager with dynamic host policy
|
||||
p.autocertManager = &autocert.Manager{
|
||||
Cache: autocert.DirCache(p.config.CertCacheDir),
|
||||
Prompt: autocert.AcceptTOS,
|
||||
Email: p.config.TLSEmail,
|
||||
HostPolicy: p.dynamicHostPolicy, // Use dynamic policy based on routes
|
||||
}
|
||||
|
||||
// Start HTTP server for ACME challenges
|
||||
p.httpServer = &http.Server{
|
||||
Addr: p.config.HTTPListenAddress,
|
||||
Handler: p.autocertManager.HTTPHandler(nil),
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting HTTP server on %s for ACME challenges", p.config.HTTPListenAddress)
|
||||
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Errorf("HTTP server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Start HTTPS server
|
||||
p.server = &http.Server{
|
||||
Addr: p.config.ListenAddress,
|
||||
Handler: handler,
|
||||
TLSConfig: p.autocertManager.TLSConfig(),
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting HTTPS server on %s", p.config.ListenAddress)
|
||||
if err := p.server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
|
||||
log.Errorf("HTTPS server error: %v", err)
|
||||
p.mu.Lock()
|
||||
p.isRunning = false
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
// Start HTTP server only
|
||||
p.server = &http.Server{
|
||||
Addr: p.config.HTTPListenAddress,
|
||||
Handler: handler,
|
||||
}
|
||||
|
||||
go func() {
|
||||
log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress)
|
||||
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
log.Errorf("HTTP server error: %v", err)
|
||||
p.mu.Lock()
|
||||
p.isRunning = false
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
log.Infof("Reverse proxy started with %d route(s)", len(p.routes))
|
||||
return nil
|
||||
}
|
||||
|
||||
// dynamicHostPolicy is a custom host policy that allows certificates for any domain
|
||||
// that has a configured route
|
||||
func (p *Proxy) dynamicHostPolicy(ctx context.Context, host string) error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// Strip port if present
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// O(1) lookup for exact domain match
|
||||
if _, exists := p.routes[host]; exists {
|
||||
log.Infof("Allowing certificate for domain: %s", host)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Warnf("Rejecting certificate request for unknown domain: %s", host)
|
||||
return fmt.Errorf("domain %s not configured in routes", host)
|
||||
}
|
||||
|
||||
// Stop gracefully stops the reverse proxy
|
||||
func (p *Proxy) Stop(ctx context.Context) error {
|
||||
p.mu.Lock()
|
||||
if !p.isRunning {
|
||||
p.mu.Unlock()
|
||||
return fmt.Errorf("reverse proxy not running")
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
log.Info("Stopping reverse proxy...")
|
||||
|
||||
// Stop HTTPS server
|
||||
if p.server != nil {
|
||||
if err := p.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown HTTPS server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop HTTP server (ACME challenge server)
|
||||
if p.httpServer != nil {
|
||||
if err := p.httpServer.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown HTTP server: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
p.isRunning = false
|
||||
p.mu.Unlock()
|
||||
|
||||
log.Info("Reverse proxy stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildHandler creates the main HTTP handler with router for static endpoints
|
||||
func (p *Proxy) buildHandler() http.Handler {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Register static endpoints
|
||||
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
|
||||
|
||||
// Catch-all handler for dynamic proxy routing
|
||||
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
// handleProxyRequest handles all dynamic proxy requests
|
||||
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
|
||||
routeEntry := p.findRoute(r.Host, r.URL.Path)
|
||||
if routeEntry == nil {
|
||||
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
rw := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
routeEntry.handler.ServeHTTP(rw, r)
|
||||
|
||||
if p.requestCallback != nil {
|
||||
duration := time.Since(startTime)
|
||||
|
||||
host := r.Host
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
authMechanism := r.Header.Get("X-Auth-Method")
|
||||
if authMechanism == "" {
|
||||
authMechanism = "none"
|
||||
}
|
||||
|
||||
// Determine auth success based on status code
|
||||
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
|
||||
|
||||
// Extract user ID (this would need to be enhanced to extract from tokens/headers)
|
||||
_, userID, _ := extractAuthInfo(r, rw.statusCode)
|
||||
|
||||
data := RequestData{
|
||||
ServiceID: routeEntry.routeConfig.ID,
|
||||
Host: host,
|
||||
Path: r.URL.Path,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
Method: r.Method,
|
||||
ResponseCode: int32(rw.statusCode),
|
||||
SourceIP: extractSourceIP(r),
|
||||
AuthMechanism: authMechanism,
|
||||
UserID: userID,
|
||||
AuthSuccess: authSuccess,
|
||||
}
|
||||
|
||||
p.requestCallback(data)
|
||||
}
|
||||
}
|
||||
|
||||
// findRoute finds the matching route for a given host and path
|
||||
func (p *Proxy) findRoute(host, path string) *routeEntry {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// Strip port from host
|
||||
if idx := strings.LastIndex(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
// O(1) lookup by host
|
||||
routeConfig, exists := p.routes[host]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build list of route entries sorted by path specificity
|
||||
var entries []*routeEntry
|
||||
|
||||
// Create entries for each path mapping
|
||||
for routePath, target := range routeConfig.PathMappings {
|
||||
proxy := p.createProxy(routeConfig, target)
|
||||
|
||||
// ALWAYS wrap proxy with auth middleware (even if no auth configured)
|
||||
// This ensures consistent auth handling and logging
|
||||
handler := wrapWithAuth(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.config.OIDCConfig)
|
||||
|
||||
// Log auth configuration
|
||||
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
|
||||
var authType string
|
||||
if routeConfig.AuthConfig.BasicAuth != nil {
|
||||
authType = "basic_auth"
|
||||
} else if routeConfig.AuthConfig.PIN != nil {
|
||||
authType = "pin"
|
||||
} else if routeConfig.AuthConfig.Bearer != nil {
|
||||
authType = "bearer_jwt"
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeConfig.ID,
|
||||
"auth_type": authType,
|
||||
}).Debug("Auth middleware enabled for route")
|
||||
} else {
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": routeConfig.ID,
|
||||
}).Debug("No authentication configured for route")
|
||||
}
|
||||
|
||||
entries = append(entries, &routeEntry{
|
||||
routeConfig: routeConfig,
|
||||
path: routePath,
|
||||
target: target,
|
||||
proxy: proxy,
|
||||
handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by path specificity (longest first)
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
pi, pj := entries[i].path, entries[j].path
|
||||
// Empty string or "/" goes last (catch-all)
|
||||
if pi == "" || pi == "/" {
|
||||
return false
|
||||
}
|
||||
if pj == "" || pj == "/" {
|
||||
return true
|
||||
}
|
||||
return len(pi) > len(pj)
|
||||
})
|
||||
|
||||
// Find first matching entry
|
||||
for _, entry := range entries {
|
||||
if entry.path == "" || entry.path == "/" {
|
||||
// Catch-all route
|
||||
return entry
|
||||
}
|
||||
if strings.HasPrefix(path, entry.path) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createProxy creates a reverse proxy for a target with the route's connection
|
||||
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
|
||||
// Parse target URL
|
||||
targetURL, err := url.Parse("http://" + target)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse target URL %s: %v", target, err)
|
||||
// Return a proxy that returns 502
|
||||
return &httputil.ReverseProxy{
|
||||
Director: func(req *http.Request) {},
|
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Create reverse proxy
|
||||
proxy := httputil.NewSingleHostReverseProxy(targetURL)
|
||||
|
||||
// Check if this is a defaultConn (for testing)
|
||||
if dc, ok := routeConfig.Conn.(*defaultConn); ok {
|
||||
// For defaultConn, use its dialer directly
|
||||
proxy.Transport = &http.Transport{
|
||||
DialContext: dc.dialer.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
log.Infof("Using default network dialer for route %s (testing mode)", routeConfig.ID)
|
||||
} else {
|
||||
// Configure transport to use the provided connection (WireGuard, etc.)
|
||||
proxy.Transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
log.Debugf("Using custom connection for route %s to %s", routeConfig.ID, address)
|
||||
return routeConfig.Conn, nil
|
||||
},
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
IdleConnTimeout: 0, // Keep alive indefinitely
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
log.Infof("Using custom connection for route %s", routeConfig.ID)
|
||||
}
|
||||
|
||||
// Custom error handler
|
||||
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
|
||||
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||
}
|
||||
|
||||
return proxy
|
||||
}
|
||||
|
||||
// AddRoute adds a new route configuration
|
||||
func (p *Proxy) AddRoute(route *RouteConfig) error {
|
||||
if route == nil {
|
||||
return fmt.Errorf("route cannot be nil")
|
||||
}
|
||||
if route.ID == "" {
|
||||
return fmt.Errorf("route ID is required")
|
||||
}
|
||||
if route.Domain == "" {
|
||||
return fmt.Errorf("route Domain is required")
|
||||
}
|
||||
if len(route.PathMappings) == 0 {
|
||||
return fmt.Errorf("route must have at least one path mapping")
|
||||
}
|
||||
if route.Conn == nil {
|
||||
return fmt.Errorf("route connection (Conn) is required")
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Check if route already exists for this domain
|
||||
if _, exists := p.routes[route.Domain]; exists {
|
||||
return fmt.Errorf("route for domain %s already exists", route.Domain)
|
||||
}
|
||||
|
||||
// Add route with domain as key
|
||||
p.routes[route.Domain] = route
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": route.ID,
|
||||
"domain": route.Domain,
|
||||
"paths": len(route.PathMappings),
|
||||
}).Info("Added route")
|
||||
|
||||
// Note: With this architecture, we don't need to reload the server
|
||||
// The handler dynamically looks up routes on each request
|
||||
// Certificates will be obtained automatically when the domain is first accessed
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route
|
||||
func (p *Proxy) RemoveRoute(domain string) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Check if route exists
|
||||
if _, exists := p.routes[domain]; !exists {
|
||||
return fmt.Errorf("route for domain %s not found", domain)
|
||||
}
|
||||
|
||||
// Remove route
|
||||
delete(p.routes, domain)
|
||||
|
||||
log.Infof("Removed route for domain: %s", domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoute updates an existing route
|
||||
func (p *Proxy) UpdateRoute(route *RouteConfig) error {
|
||||
if route == nil {
|
||||
return fmt.Errorf("route cannot be nil")
|
||||
}
|
||||
if route.ID == "" {
|
||||
return fmt.Errorf("route ID is required")
|
||||
}
|
||||
if route.Domain == "" {
|
||||
return fmt.Errorf("route Domain is required")
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// Check if route exists for this domain
|
||||
if _, exists := p.routes[route.Domain]; !exists {
|
||||
return fmt.Errorf("route for domain %s not found", route.Domain)
|
||||
}
|
||||
|
||||
// Update route using domain as key
|
||||
p.routes[route.Domain] = route
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"route_id": route.ID,
|
||||
"domain": route.Domain,
|
||||
"paths": len(route.PathMappings),
|
||||
}).Info("Updated route")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of all configured domains
|
||||
func (p *Proxy) ListRoutes() []string {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
domains := make([]string, 0, len(p.routes))
|
||||
for domain := range p.routes {
|
||||
domains = append(domains, domain)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// GetRoute returns a route configuration by domain
|
||||
func (p *Proxy) GetRoute(domain string) (*RouteConfig, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
route, exists := p.routes[domain]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("route for domain %s not found", domain)
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the proxy is running
|
||||
func (p *Proxy) IsRunning() bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
return p.isRunning
|
||||
}
|
||||
|
||||
// GetConfig returns the proxy configuration
|
||||
func (p *Proxy) GetConfig() Config {
|
||||
return p.config
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture status code
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// extractSourceIP extracts the source IP from the request
|
||||
func extractSourceIP(r *http.Request) string {
|
||||
// Try X-Forwarded-For header first
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP in the list
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Try X-Real-IP header
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// extractAuthInfo extracts authentication information from the request
|
||||
// Returns: authMechanism, userID, authSuccess
|
||||
func extractAuthInfo(r *http.Request, statusCode int) (string, string, bool) {
|
||||
// Check if authentication succeeded based on status code
|
||||
// 401 = Unauthorized, 403 = Forbidden
|
||||
authSuccess := statusCode != http.StatusUnauthorized && statusCode != http.StatusForbidden
|
||||
|
||||
// Check for Bearer token (JWT, OAuth2, etc.)
|
||||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||
if strings.HasPrefix(auth, "Bearer ") {
|
||||
// Extract user ID from JWT if possible (you may want to decode the JWT here)
|
||||
// For now, we'll just indicate it's a bearer token
|
||||
return "bearer", extractUserIDFromBearer(auth), authSuccess
|
||||
}
|
||||
if strings.HasPrefix(auth, "Basic ") {
|
||||
// Basic authentication
|
||||
return "basic", extractUserIDFromBasic(auth), authSuccess
|
||||
}
|
||||
// Other authorization schemes
|
||||
return "other", "", authSuccess
|
||||
}
|
||||
|
||||
// Check for API key in headers
|
||||
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
||||
return "api_key", "", authSuccess
|
||||
}
|
||||
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
|
||||
return "api_key", "", authSuccess
|
||||
}
|
||||
|
||||
// Check for mutual TLS (client certificate)
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
// Extract Common Name from client certificate
|
||||
cn := r.TLS.PeerCertificates[0].Subject.CommonName
|
||||
return "mtls", cn, authSuccess
|
||||
}
|
||||
|
||||
// Check for session cookie (common in web apps)
|
||||
if cookie, err := r.Cookie("session"); err == nil && cookie.Value != "" {
|
||||
return "session", "", authSuccess
|
||||
}
|
||||
|
||||
// No authentication detected
|
||||
return "none", "", authSuccess
|
||||
}
|
||||
|
||||
// extractUserIDFromBearer attempts to extract user ID from Bearer token
|
||||
// Decodes the JWT (without verification) to extract the user ID from standard claims
|
||||
func extractUserIDFromBearer(auth string) string {
|
||||
// Remove "Bearer " prefix
|
||||
tokenString := strings.TrimPrefix(auth, "Bearer ")
|
||||
if tokenString == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// JWT format: header.payload.signature
|
||||
// We only need the payload to extract user ID (no verification needed here)
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
log.Debug("Invalid JWT format: expected 3 parts")
|
||||
return ""
|
||||
}
|
||||
|
||||
// Decode the payload (second part)
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
log.WithError(err).Debug("Failed to decode JWT payload")
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse JSON payload
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
log.WithError(err).Debug("Failed to parse JWT claims")
|
||||
return ""
|
||||
}
|
||||
|
||||
// Try standard user ID claims in order of preference
|
||||
// 1. "sub" (standard JWT subject claim)
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
|
||||
// 2. "user_id" (common in some systems)
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
|
||||
// 3. "email" (fallback)
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
// 4. "preferred_username" (used by some OIDC providers)
|
||||
if username, ok := claims["preferred_username"].(string); ok && username != "" {
|
||||
return username
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractUserIDFromBasic extracts username from Basic auth header
|
||||
func extractUserIDFromBasic(auth string) string {
|
||||
// Basic auth format: "Basic base64(username:password)"
|
||||
_ = strings.TrimPrefix(auth, "Basic ")
|
||||
// Note: We're not decoding it here for security reasons
|
||||
// The upstream service should handle the actual authentication
|
||||
// We just note that basic auth was used
|
||||
return ""
|
||||
}
|
||||
|
||||
// defaultConn is a lazy connection wrapper that uses the standard network dialer
|
||||
// This is useful for testing or development when not using WireGuard tunnels
|
||||
type defaultConn struct {
|
||||
dialer *net.Dialer
|
||||
mu sync.Mutex
|
||||
conns map[string]net.Conn // cache connections by "network:address"
|
||||
}
|
||||
|
||||
func (dc *defaultConn) Read(b []byte) (n int, err error) {
|
||||
return 0, fmt.Errorf("Read not supported on defaultConn - use dial via Transport")
|
||||
}
|
||||
|
||||
func (dc *defaultConn) Write(b []byte) (n int, err error) {
|
||||
return 0, fmt.Errorf("Write not supported on defaultConn - use dial via Transport")
|
||||
}
|
||||
|
||||
func (dc *defaultConn) Close() error {
|
||||
dc.mu.Lock()
|
||||
defer dc.mu.Unlock()
|
||||
|
||||
for _, conn := range dc.conns {
|
||||
conn.Close()
|
||||
}
|
||||
dc.conns = make(map[string]net.Conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dc *defaultConn) LocalAddr() net.Addr { return nil }
|
||||
func (dc *defaultConn) RemoteAddr() net.Addr { return nil }
|
||||
func (dc *defaultConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (dc *defaultConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (dc *defaultConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// NewDefaultConn creates a connection wrapper that uses the standard network dialer
|
||||
// This is useful for testing or development when not using WireGuard tunnels
|
||||
// The actual dialing happens when the HTTP Transport calls DialContext
|
||||
func NewDefaultConn() net.Conn {
|
||||
return &defaultConn{
|
||||
dialer: &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
},
|
||||
conns: make(map[string]net.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
|
||||
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if OIDC is configured globally
|
||||
if p.config.OIDCConfig == nil {
|
||||
log.Error("OIDC callback received but no OIDC config found")
|
||||
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Use the HandleOIDCCallback function from auth.go with global config
|
||||
handler := HandleOIDCCallback(p.config.OIDCConfig)
|
||||
handler(w, r)
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// customTransportRegistry stores custom dialers and connections globally
|
||||
// This allows them to be accessed after Caddy deserializes the configuration from JSON
|
||||
var customTransportRegistry = &transportRegistry{
|
||||
transports: make(map[string]*customTransport),
|
||||
}
|
||||
|
||||
// transportRegistry manages custom transports for routes
|
||||
type transportRegistry struct {
|
||||
mu sync.RWMutex
|
||||
transports map[string]*customTransport // key is "routeID:path"
|
||||
}
|
||||
|
||||
// customTransport wraps either a net.Conn or a custom dialer
|
||||
type customTransport struct {
|
||||
routeID string
|
||||
path string
|
||||
conn net.Conn
|
||||
customDialer func(ctx context.Context, network, address string) (net.Conn, error)
|
||||
defaultDialer *net.Dialer
|
||||
}
|
||||
|
||||
// Register registers a custom transport for a route
|
||||
func (r *transportRegistry) Register(routeID, path string, conn net.Conn, dialer func(ctx context.Context, network, address string) (net.Conn, error)) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
key := fmt.Sprintf("%s:%s", routeID, path)
|
||||
r.transports[key] = &customTransport{
|
||||
routeID: routeID,
|
||||
path: path,
|
||||
conn: conn,
|
||||
customDialer: dialer,
|
||||
defaultDialer: &net.Dialer{Timeout: 30 * time.Second},
|
||||
}
|
||||
|
||||
if conn != nil {
|
||||
log.Infof("Registered net.Conn transport for route %s (path: %s)", routeID, path)
|
||||
} else if dialer != nil {
|
||||
log.Infof("Registered custom dialer transport for route %s (path: %s)", routeID, path)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a custom transport for a route
|
||||
func (r *transportRegistry) Get(routeID, path string) *customTransport {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
key := fmt.Sprintf("%s:%s", routeID, path)
|
||||
return r.transports[key]
|
||||
}
|
||||
|
||||
// Unregister removes a custom transport
|
||||
func (r *transportRegistry) Unregister(routeID, path string) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
key := fmt.Sprintf("%s:%s", routeID, path)
|
||||
delete(r.transports, key)
|
||||
log.Infof("Unregistered transport for route %s (path: %s)", routeID, path)
|
||||
}
|
||||
|
||||
// Clear removes all custom transports
|
||||
func (r *transportRegistry) Clear() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.transports = make(map[string]*customTransport)
|
||||
log.Info("Cleared all custom transports")
|
||||
}
|
||||
|
||||
// DialContext implements the DialContext function for custom transports
|
||||
func (ct *customTransport) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
// If we have a pre-existing connection, return it
|
||||
if ct.conn != nil {
|
||||
log.Debugf("Reusing existing connection for route %s (path: %s) to %s", ct.routeID, ct.path, address)
|
||||
return ct.conn, nil
|
||||
}
|
||||
|
||||
// If we have a custom dialer, use it
|
||||
if ct.customDialer != nil {
|
||||
log.Debugf("Using custom dialer for route %s (path: %s) to %s", ct.routeID, ct.path, address)
|
||||
return ct.customDialer(ctx, network, address)
|
||||
}
|
||||
|
||||
// Fallback to default dialer (this shouldn't happen if registered correctly)
|
||||
log.Warnf("No custom transport found for route %s (path: %s), using default dialer", ct.routeID, ct.path)
|
||||
return ct.defaultDialer.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// NewCustomHTTPTransport creates an HTTP transport that uses the custom dialer
|
||||
func NewCustomHTTPTransport(routeID, path string) *http.Transport {
|
||||
transport := customTransportRegistry.Get(routeID, path)
|
||||
if transport == nil {
|
||||
// No custom transport registered, return standard transport
|
||||
log.Warnf("No custom transport found for route %s (path: %s), using standard transport", routeID, path)
|
||||
return &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Configure transport based on whether we're using a connection or dialer
|
||||
if transport.conn != nil {
|
||||
// Using a pre-existing connection - disable pooling
|
||||
return &http.Transport{
|
||||
DialContext: transport.DialContext,
|
||||
MaxIdleConns: 1,
|
||||
MaxIdleConnsPerHost: 1,
|
||||
IdleConnTimeout: 0, // Keep alive indefinitely
|
||||
DisableKeepAlives: false,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Using a custom dialer - use normal pooling
|
||||
return &http.Transport{
|
||||
DialContext: transport.DialContext,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user