This commit is contained in:
pascal
2026-01-15 14:54:33 +01:00
parent 12b38e25da
commit ed5f98da5b
22 changed files with 1511 additions and 1392 deletions

View File

@@ -0,0 +1,25 @@
package auth
import "github.com/netbirdio/netbird/proxy/internal/auth/methods"
// Config holds the authentication configuration for a route
// Only ONE auth method should be configured per route
type Config struct {
// HTTP Basic authentication (username/password)
BasicAuth *methods.BasicAuthConfig
// PIN authentication
PIN *methods.PINConfig
// Bearer token with JWT validation and OAuth/OIDC flow
// When enabled, uses the global OIDCConfig from proxy Config
Bearer *methods.BearerConfig
}
// IsEmpty returns true if no auth methods are configured
func (c *Config) IsEmpty() bool {
if c == nil {
return true
}
return c.BasicAuth == nil && c.PIN == nil && c.Bearer == nil
}

View File

@@ -0,0 +1,9 @@
package auth
const (
// DefaultSessionCookieName is the default cookie name for session storage
DefaultSessionCookieName = "auth_session"
// ErrorInternalServer is the default internal server error message
ErrorInternalServer = "Internal Server Error"
)

View File

@@ -0,0 +1,26 @@
package methods
import (
"crypto/subtle"
"net/http"
)
// BasicAuthConfig holds HTTP Basic authentication settings
type BasicAuthConfig struct {
Username string
Password string
}
// Validate checks Basic Auth credentials from the request
func (c *BasicAuthConfig) Validate(r *http.Request) bool {
username, password, ok := r.BasicAuth()
if !ok {
return false
}
// Use constant-time comparison to prevent timing attacks
usernameMatch := subtle.ConstantTimeCompare([]byte(username), []byte(c.Username)) == 1
passwordMatch := subtle.ConstantTimeCompare([]byte(password), []byte(c.Password)) == 1
return usernameMatch && passwordMatch
}

View File

@@ -0,0 +1,10 @@
package methods
// BearerConfig holds JWT/OAuth/OIDC bearer token authentication settings
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
// This just enables Bearer auth for a specific route
type BearerConfig struct {
// Enable bearer token authentication for this route
// Uses the global OIDC configuration from proxy Config
Enabled bool
}

View File

@@ -0,0 +1,33 @@
package methods
import (
"crypto/subtle"
"net/http"
)
const (
// DefaultPINHeader is the default header name for PIN authentication
DefaultPINHeader = "X-PIN"
)
// PINConfig holds PIN authentication settings
type PINConfig struct {
PIN string
Header string // Header name (default: "X-PIN")
}
// Validate checks PIN from the request header
func (c *PINConfig) Validate(r *http.Request) bool {
header := c.Header
if header == "" {
header = DefaultPINHeader
}
providedPIN := r.Header.Get(header)
if providedPIN == "" {
return false
}
// Use constant-time comparison to prevent timing attacks
return subtle.ConstantTimeCompare([]byte(providedPIN), []byte(c.PIN)) == 1
}

View File

@@ -0,0 +1,312 @@
package auth
import (
"fmt"
"net/http"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
)
// Middleware wraps an HTTP handler with authentication middleware
type Middleware struct {
next http.Handler
config *Config
routeID string
rejectResponse func(w http.ResponseWriter, r *http.Request)
oidcHandler *oidc.Handler // OIDC handler for OAuth flow (contains config and JWT validator)
}
// authResult holds the result of an authentication attempt
type authResult struct {
authenticated bool
method string
userID string
}
// ServeHTTP implements the http.Handler interface
func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If no auth configured, allow request
if m.config.IsEmpty() {
m.allowWithoutAuth(w, r)
return
}
// Try to authenticate the request
result := m.authenticate(w, r)
if result == nil {
// Authentication triggered a redirect (e.g., OIDC flow)
return
}
// Reject if authentication failed
if !result.authenticated {
m.rejectRequest(w, r)
return
}
// Authentication successful - continue to next handler
m.continueWithAuth(w, r, result)
}
// allowWithoutAuth allows requests when no authentication is configured
func (m *Middleware) allowWithoutAuth(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{
"route_id": m.routeID,
"auth_method": "none",
"path": r.URL.Path,
}).Debug("No authentication configured, allowing request")
r.Header.Set("X-Auth-Method", "none")
m.next.ServeHTTP(w, r)
}
// authenticate attempts to authenticate the request using configured methods
// Returns nil if a redirect occurred (e.g., OIDC flow initiated)
func (m *Middleware) authenticate(w http.ResponseWriter, r *http.Request) *authResult {
// Try Basic Auth
if result := m.tryBasicAuth(r); result.authenticated {
return result
}
// Try PIN Auth
if result := m.tryPINAuth(r); result.authenticated {
return result
}
// Try Bearer/OIDC Auth
return m.tryBearerAuth(w, r)
}
// tryBasicAuth attempts Basic authentication
func (m *Middleware) tryBasicAuth(r *http.Request) *authResult {
if m.config.BasicAuth == nil {
return &authResult{}
}
if !m.config.BasicAuth.Validate(r) {
return &authResult{}
}
result := &authResult{
authenticated: true,
method: "basic",
}
// Extract username from Basic Auth
if username, _, ok := r.BasicAuth(); ok {
result.userID = username
}
return result
}
// tryPINAuth attempts PIN authentication
func (m *Middleware) tryPINAuth(r *http.Request) *authResult {
if m.config.PIN == nil {
return &authResult{}
}
if !m.config.PIN.Validate(r) {
return &authResult{}
}
return &authResult{
authenticated: true,
method: "pin",
userID: "pin_user",
}
}
// tryBearerAuth attempts Bearer token authentication with JWT validation
// Returns nil if OIDC redirect occurred
func (m *Middleware) tryBearerAuth(w http.ResponseWriter, r *http.Request) *authResult {
if m.config.Bearer == nil || m.oidcHandler == nil {
return &authResult{}
}
cookieName := m.oidcHandler.SessionCookieName()
// Handle auth token in query parameter (from OIDC callback)
if m.handleAuthTokenParameter(w, r, cookieName) {
return nil // Redirect occurred
}
// Try session cookie
if result := m.trySessionCookie(r, cookieName); result.authenticated {
return result
}
// Try Authorization header
if result := m.tryAuthorizationHeader(r); result.authenticated {
return result
}
// No valid auth - redirect to OIDC provider
m.oidcHandler.RedirectToProvider(w, r, m.routeID)
return nil // Redirect occurred
}
// handleAuthTokenParameter processes the _auth_token query parameter from OIDC callback
// Returns true if a redirect occurred
func (m *Middleware) handleAuthTokenParameter(w http.ResponseWriter, r *http.Request, cookieName string) bool {
authToken := r.URL.Query().Get("_auth_token")
if authToken == "" {
return false
}
log.WithFields(log.Fields{
"route_id": m.routeID,
"host": r.Host,
}).Info("Found auth token in query parameter, setting cookie and redirecting")
// Validate the token before setting cookie
if !m.oidcHandler.ValidateJWT(authToken) {
log.WithFields(log.Fields{
"route_id": m.routeID,
}).Warn("Invalid token in query parameter")
return false
}
// Set session cookie
cookie := &http.Cookie{
Name: cookieName,
Value: authToken,
Path: "/",
MaxAge: 3600, // 1 hour
HttpOnly: true,
Secure: false, // Set to false for HTTP testing, true for HTTPS in production
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, cookie)
// Redirect to same URL without the token parameter
redirectURL := m.buildCleanRedirectURL(r)
log.WithFields(log.Fields{
"route_id": m.routeID,
"redirect_url": redirectURL,
}).Debug("Redirecting to clean URL after setting cookie")
http.Redirect(w, r, redirectURL, http.StatusFound)
return true
}
// buildCleanRedirectURL builds a redirect URL without the _auth_token parameter
func (m *Middleware) buildCleanRedirectURL(r *http.Request) string {
cleanURL := *r.URL
q := cleanURL.Query()
q.Del("_auth_token")
cleanURL.RawQuery = q.Encode()
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
return fmt.Sprintf("%s://%s%s", scheme, r.Host, cleanURL.String())
}
// trySessionCookie attempts authentication using a session cookie
func (m *Middleware) trySessionCookie(r *http.Request, cookieName string) *authResult {
log.WithFields(log.Fields{
"route_id": m.routeID,
"cookie_name": cookieName,
"host": r.Host,
"path": r.URL.Path,
}).Debug("Checking for session cookie")
cookie, err := r.Cookie(cookieName)
if err != nil || cookie.Value == "" {
log.WithFields(log.Fields{
"route_id": m.routeID,
"error": err,
}).Debug("No session cookie found")
return &authResult{}
}
log.WithFields(log.Fields{
"route_id": m.routeID,
"cookie_name": cookieName,
}).Debug("Session cookie found, validating JWT")
if !m.oidcHandler.ValidateJWT(cookie.Value) {
log.WithFields(log.Fields{
"route_id": m.routeID,
}).Debug("JWT validation failed for session cookie")
return &authResult{}
}
return &authResult{
authenticated: true,
method: "bearer_session",
userID: m.oidcHandler.ExtractUserID(cookie.Value),
}
}
// tryAuthorizationHeader attempts authentication using the Authorization header
func (m *Middleware) tryAuthorizationHeader(r *http.Request) *authResult {
authHeader := r.Header.Get("Authorization")
if authHeader == "" || !strings.HasPrefix(authHeader, "Bearer ") {
return &authResult{}
}
token := strings.TrimPrefix(authHeader, "Bearer ")
if !m.oidcHandler.ValidateJWT(token) {
return &authResult{}
}
return &authResult{
authenticated: true,
method: "bearer",
userID: m.oidcHandler.ExtractUserID(token),
}
}
// rejectRequest rejects an unauthenticated request
func (m *Middleware) rejectRequest(w http.ResponseWriter, r *http.Request) {
log.WithFields(log.Fields{
"route_id": m.routeID,
"path": r.URL.Path,
}).Warn("Authentication failed")
if m.rejectResponse != nil {
m.rejectResponse(w, r)
} else {
w.Header().Set("WWW-Authenticate", `Bearer realm="Restricted"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
}
// continueWithAuth continues the request with authenticated user info
func (m *Middleware) continueWithAuth(w http.ResponseWriter, r *http.Request, result *authResult) {
log.WithFields(log.Fields{
"route_id": m.routeID,
"auth_method": result.method,
"user_id": result.userID,
"path": r.URL.Path,
}).Debug("Authentication successful")
// Store auth info in headers for logging
r.Header.Set("X-Auth-Method", result.method)
r.Header.Set("X-Auth-User-ID", result.userID)
// Continue to next handler
m.next.ServeHTTP(w, r)
}
// Wrap wraps an HTTP handler with authentication middleware
func Wrap(next http.Handler, authConfig *Config, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcHandler *oidc.Handler) http.Handler {
if authConfig == nil {
authConfig = &Config{} // Empty config = no auth
}
return &Middleware{
next: next,
config: authConfig,
routeID: routeID,
rejectResponse: rejectResponse,
oidcHandler: oidcHandler,
}
}

View File

@@ -0,0 +1,20 @@
package oidc
// Config holds the global OIDC/OAuth configuration
type Config struct {
// OIDC Provider settings
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"` // Identity provider URL (e.g., "https://accounts.google.com")
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"` // OAuth client ID
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"` // OAuth client secret (empty for public clients)
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"` // Redirect URL after auth (e.g., "http://localhost:54321/auth/callback")
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"` // Requested scopes (default: ["openid", "profile", "email"])
// JWT Validation settings
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"` // JWKS URL for fetching public keys
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"` // Expected issuer claim
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"` // Expected audience claims
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"` // Enable automatic refresh of signing keys
// Session settings
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"` // Cookie name for storing session (default: "auth_session")
}

View File

@@ -0,0 +1,291 @@
package oidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/auth/jwt"
)
// Handler manages OIDC authentication flow
type Handler struct {
config *Config
stateStore *StateStore
jwtValidator *jwt.Validator
}
// NewHandler creates a new OIDC handler
func NewHandler(config *Config, stateStore *StateStore) *Handler {
// Initialize JWT validator
var jwtValidator *jwt.Validator
if config.JWTKeysLocation != "" {
jwtValidator = jwt.NewValidator(
config.JWTIssuer,
config.JWTAudience,
config.JWTKeysLocation,
config.JWTIdpSignkeyRefreshEnabled,
)
}
return &Handler{
config: config,
stateStore: stateStore,
jwtValidator: jwtValidator,
}
}
// RedirectToProvider initiates the OAuth/OIDC authentication flow by redirecting to the provider
func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, routeID string) {
// Generate random state for CSRF protection
state, err := generateRandomString(32)
if err != nil {
log.WithError(err).Error("Failed to generate OIDC state")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Store state with original URL for redirect after auth
// Include the full URL with scheme and host
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
h.stateStore.Store(state, originalURL, routeID)
// Default scopes if not configured
scopes := h.config.Scopes
if len(scopes) == 0 {
scopes = []string{"openid", "profile", "email"}
}
// Build authorization URL
authURL, err := url.Parse(h.config.ProviderURL)
if err != nil {
log.WithError(err).Error("Invalid OIDC provider URL")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Append /authorize if it doesn't exist (common OIDC endpoint)
if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
}
// Build query parameters
params := url.Values{}
params.Set("client_id", h.config.ClientID)
params.Set("redirect_uri", h.config.RedirectURL)
params.Set("response_type", "code")
params.Set("scope", strings.Join(scopes, " "))
params.Set("state", state)
// Add audience parameter to get an access token for the API
// This ensures we get a proper JWT for the API audience, not just an ID token
if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
params.Set("audience", h.config.JWTAudience[0])
}
authURL.RawQuery = params.Encode()
log.WithFields(log.Fields{
"route_id": routeID,
"provider_url": authURL.String(),
"redirect_url": h.config.RedirectURL,
"state": state,
}).Info("Redirecting to OIDC provider for authentication")
// Redirect user to identity provider login page
http.Redirect(w, r, authURL.String(), http.StatusFound)
}
// HandleCallback creates an HTTP handler for the OIDC callback endpoint
func (h *Handler) HandleCallback() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get authorization code and state from query parameters
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" || state == "" {
log.Error("Missing code or state in OIDC callback")
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
return
}
// Verify state to prevent CSRF
oidcSt, ok := h.stateStore.Get(state)
if !ok {
log.Error("Invalid or expired OIDC state")
http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
return
}
// Delete state to prevent reuse
h.stateStore.Delete(state)
// Exchange authorization code for token
token, err := h.exchangeCodeForToken(code)
if err != nil {
log.WithError(err).Error("Failed to exchange code for token")
http.Error(w, "Authentication failed", http.StatusUnauthorized)
return
}
// Parse the original URL to add the token as a query parameter
origURL, err := url.Parse(oidcSt.OriginalURL)
if err != nil {
log.WithError(err).Error("Failed to parse original URL")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Add token as query parameter so the original domain can set its own cookie
// We use a special parameter name that the auth middleware will look for
q := origURL.Query()
q.Set("_auth_token", token)
origURL.RawQuery = q.Encode()
log.WithFields(log.Fields{
"route_id": oidcSt.RouteID,
"original_url": oidcSt.OriginalURL,
"redirect_url": origURL.String(),
"callback_host": r.Host,
}).Info("OIDC authentication successful, redirecting with token parameter")
// Redirect back to original URL with token parameter
http.Redirect(w, r, origURL.String(), http.StatusFound)
}
}
// exchangeCodeForToken exchanges an authorization code for an access token
func (h *Handler) exchangeCodeForToken(code string) (string, error) {
// Build token endpoint URL
tokenURL, err := url.Parse(h.config.ProviderURL)
if err != nil {
return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
}
// Auth0 uses /oauth/token, standard OIDC uses /token
// Check if path already contains token endpoint
if !strings.Contains(tokenURL.Path, "/token") {
tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
}
// Build request body
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", h.config.RedirectURL)
data.Set("client_id", h.config.ClientID)
// Only include client_secret if it's provided (not needed for public/SPA clients)
if h.config.ClientSecret != "" {
data.Set("client_secret", h.config.ClientSecret)
}
// Make token exchange request
resp, err := http.PostForm(tokenURL.String(), data)
if err != nil {
return "", fmt.Errorf("token exchange request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var tokenResp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
if tokenResp.AccessToken == "" {
return "", fmt.Errorf("no access token in response")
}
// Return the ID token if available (contains user claims), otherwise access token
if tokenResp.IDToken != "" {
return tokenResp.IDToken, nil
}
return tokenResp.AccessToken, nil
}
// ValidateJWT validates a JWT token
func (h *Handler) ValidateJWT(tokenString string) bool {
if h.jwtValidator == nil {
log.Error("JWT validation failed: JWT validator not initialized")
return false
}
// Validate the token
ctx := context.Background()
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
if err != nil {
log.WithError(err).Error("JWT validation failed")
// Try to parse token without validation to see what's in it
parts := strings.Split(tokenString, ".")
if len(parts) == 3 {
// Decode payload (middle part)
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err == nil {
log.WithFields(log.Fields{
"payload": string(payload),
}).Debug("Token payload for debugging")
}
}
return false
}
// Token is valid if parsedToken is not nil and Valid is true
return parsedToken != nil && parsedToken.Valid
}
// ExtractUserID extracts the user ID from a JWT token
func (h *Handler) ExtractUserID(tokenString string) string {
if h.jwtValidator == nil {
return ""
}
// Parse the token
ctx := context.Background()
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
if err != nil {
return ""
}
// parsedToken is already *jwtgo.Token from ValidateAndParse
// Create extractor to get user auth info
extractor := jwt.NewClaimsExtractor()
userAuth, err := extractor.ToUserAuth(parsedToken)
if err != nil {
log.WithError(err).Debug("Failed to extract user ID from JWT")
return ""
}
return userAuth.UserId
}
// SessionCookieName returns the configured session cookie name or default
func (h *Handler) SessionCookieName() string {
if h.config.SessionCookieName != "" {
return h.config.SessionCookieName
}
return "auth_session"
}

View File

@@ -0,0 +1,10 @@
package oidc
import "time"
// State represents stored OIDC state information for CSRF protection
type State struct {
OriginalURL string
CreatedAt time.Time
RouteID string
}

View File

@@ -0,0 +1,66 @@
package oidc
import (
"sync"
"time"
)
const (
// StateExpiration is how long OIDC state tokens are valid
StateExpiration = 10 * time.Minute
)
// StateStore manages OIDC state tokens for CSRF protection
type StateStore struct {
mu sync.RWMutex
states map[string]*State
}
// NewStateStore creates a new OIDC state store
func NewStateStore() *StateStore {
return &StateStore{
states: make(map[string]*State),
}
}
// Store saves a state token with associated metadata
func (s *StateStore) Store(stateToken, originalURL, routeID string) {
s.mu.Lock()
defer s.mu.Unlock()
s.states[stateToken] = &State{
OriginalURL: originalURL,
CreatedAt: time.Now(),
RouteID: routeID,
}
// Clean up expired states
s.cleanup()
}
// Get retrieves a state by token
func (s *StateStore) Get(stateToken string) (*State, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
st, ok := s.states[stateToken]
return st, ok
}
// Delete removes a state token
func (s *StateStore) Delete(stateToken string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.states, stateToken)
}
// cleanup removes expired state tokens (must be called with lock held)
func (s *StateStore) cleanup() {
cutoff := time.Now().Add(-StateExpiration)
for k, v := range s.states {
if v.CreatedAt.Before(cutoff) {
delete(s.states, k)
}
}
}

View File

@@ -0,0 +1,15 @@
package oidc
import (
"crypto/rand"
"encoding/base64"
)
// generateRandomString generates a cryptographically secure random string of the specified length
func generateRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
}

View File

@@ -1,613 +0,0 @@
package reverseproxy
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/auth/jwt"
)
const (
// Default values for authentication
defaultSessionCookieName = "auth_session"
defaultPINHeader = "X-PIN"
// OIDC state expiration time
oidcStateExpiration = 10 * time.Minute
// Error messages
errInternalServer = "Internal Server Error"
)
// Global state store for OIDC flow (state -> original URL)
var (
oidcStateStore = &stateStore{
states: make(map[string]*oidcState),
}
)
type stateStore struct {
mu sync.RWMutex
states map[string]*oidcState
}
type oidcState struct {
originalURL string
createdAt time.Time
routeID string
}
func (s *stateStore) Store(state, originalURL, routeID string) {
s.mu.Lock()
defer s.mu.Unlock()
s.states[state] = &oidcState{
originalURL: originalURL,
createdAt: time.Now(),
routeID: routeID,
}
// Clean up expired states
cutoff := time.Now().Add(-oidcStateExpiration)
for k, v := range s.states {
if v.createdAt.Before(cutoff) {
delete(s.states, k)
}
}
}
func (s *stateStore) Get(state string) (*oidcState, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
st, ok := s.states[state]
return st, ok
}
func (s *stateStore) Delete(state string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.states, state)
}
// AuthConfig holds the authentication configuration for a route
// Only ONE auth method should be configured per route
type AuthConfig struct {
// HTTP Basic authentication (username/password)
BasicAuth *BasicAuthConfig
// PIN authentication
PIN *PINConfig
// Bearer token with JWT validation and OAuth/OIDC flow
// When enabled, uses the global OIDCConfig from proxy Config
Bearer *BearerConfig
}
// BasicAuthConfig holds HTTP Basic authentication settings
type BasicAuthConfig struct {
Username string
Password string
}
// PINConfig holds PIN authentication settings
type PINConfig struct {
PIN string
Header string // Header name (default: "X-PIN")
}
// BearerConfig holds JWT/OAuth/OIDC bearer token authentication settings
// The actual OIDC/JWT configuration comes from the global proxy Config.OIDCConfig
// This just enables Bearer auth for a specific route
type BearerConfig struct {
// Enable bearer token authentication for this route
// Uses the global OIDC configuration from proxy Config
Enabled bool
}
// IsEmpty returns true if no auth methods are configured
func (c *AuthConfig) IsEmpty() bool {
if c == nil {
return true
}
return c.BasicAuth == nil && c.PIN == nil && c.Bearer == nil
}
// authMiddlewareHandler is a static middleware that checks AuthConfig
type authMiddlewareHandler struct {
next http.Handler
authConfig *AuthConfig
routeID string
rejectResponse func(w http.ResponseWriter, r *http.Request)
oidcConfig *OIDCConfig // Global OIDC configuration from proxy
jwtValidator *jwt.Validator // JWT validator instance (lazily initialized)
validatorMu sync.Mutex // Mutex for thread-safe validator initialization
}
func (h *authMiddlewareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// If no auth configured, allow request
if h.authConfig.IsEmpty() {
log.WithFields(log.Fields{
"route_id": h.routeID,
"auth_method": "none",
"path": r.URL.Path,
}).Debug("No authentication configured, allowing request")
r.Header.Set("X-Auth-Method", "none")
h.next.ServeHTTP(w, r)
return
}
var authMethod string
var userID string
authenticated := false
// 1. Check Basic Auth
if h.authConfig.BasicAuth != nil {
if auth := r.Header.Get("Authorization"); auth != "" && strings.HasPrefix(auth, "Basic ") {
encoded := strings.TrimPrefix(auth, "Basic ")
if decoded, err := base64.StdEncoding.DecodeString(encoded); err == nil {
credentials := string(decoded)
parts := strings.SplitN(credentials, ":", 2)
if len(parts) == 2 {
username, password := parts[0], parts[1]
if username == h.authConfig.BasicAuth.Username && password == h.authConfig.BasicAuth.Password {
authenticated = true
authMethod = "basic"
userID = username
}
}
}
}
}
// 2. Check PIN (if not already authenticated)
if !authenticated && h.authConfig.PIN != nil {
headerName := h.authConfig.PIN.Header
if headerName == "" {
headerName = defaultPINHeader
}
if pin := r.Header.Get(headerName); pin != "" {
if pin == h.authConfig.PIN.PIN {
authenticated = true
authMethod = "pin"
userID = "pin_user" // PIN doesn't have a specific user ID
}
}
}
// 3. Check Bearer Token with JWT validation (if not already authenticated)
if !authenticated && h.authConfig.Bearer != nil && h.oidcConfig != nil {
cookieName := h.oidcConfig.SessionCookieName
if cookieName == "" {
cookieName = defaultSessionCookieName
}
// First, check if there's an _auth_token query parameter (from callback redirect)
// This allows us to set the cookie for the current domain
if authToken := r.URL.Query().Get("_auth_token"); authToken != "" {
log.WithFields(log.Fields{
"route_id": h.routeID,
"host": r.Host,
}).Info("Found auth token in query parameter, setting cookie and redirecting")
// Validate the token before setting cookie
if h.validateJWT(authToken) {
// Set cookie for current domain
cookie := &http.Cookie{
Name: cookieName,
Value: authToken,
Path: "/",
MaxAge: 3600, // 1 hour
HttpOnly: true,
Secure: false, // Set to false for HTTP testing, true for HTTPS in production
SameSite: http.SameSiteLaxMode,
}
http.SetCookie(w, cookie)
// Redirect to same URL without the token parameter
cleanURL := *r.URL
q := cleanURL.Query()
q.Del("_auth_token")
cleanURL.RawQuery = q.Encode()
// Build full URL with scheme and host
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
redirectURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, cleanURL.String())
log.WithFields(log.Fields{
"route_id": h.routeID,
"redirect_url": redirectURL,
}).Debug("Redirecting to clean URL after setting cookie")
http.Redirect(w, r, redirectURL, http.StatusFound)
return
} else {
log.WithFields(log.Fields{
"route_id": h.routeID,
}).Warn("Invalid token in query parameter")
}
}
// Check if we have an existing session cookie (from OIDC flow)
log.WithFields(log.Fields{
"route_id": h.routeID,
"cookie_name": cookieName,
"host": r.Host,
"path": r.URL.Path,
}).Debug("Checking for session cookie")
if cookie, err := r.Cookie(cookieName); err == nil && cookie.Value != "" {
log.WithFields(log.Fields{
"route_id": h.routeID,
"cookie_name": cookieName,
}).Debug("Session cookie found, validating JWT")
// Validate the JWT token from the session cookie
if h.validateJWT(cookie.Value) {
authenticated = true
authMethod = "bearer_session"
userID = h.extractUserIDFromJWT(cookie.Value)
} else {
log.WithFields(log.Fields{
"route_id": h.routeID,
}).Debug("JWT validation failed for session cookie")
}
} else {
log.WithFields(log.Fields{
"route_id": h.routeID,
"error": err,
}).Debug("No session cookie found")
}
// If no session cookie or validation failed, check Authorization header
if !authenticated {
if auth := r.Header.Get("Authorization"); auth != "" && strings.HasPrefix(auth, "Bearer ") {
token := strings.TrimPrefix(auth, "Bearer ")
// Validate JWT token from Authorization header
if h.validateJWT(token) {
authenticated = true
authMethod = "bearer"
userID = h.extractUserIDFromJWT(token)
}
} else {
// No bearer token and no valid session - redirect to OIDC provider
if h.oidcConfig.ProviderURL != "" {
// Initiate OAuth/OIDC flow
h.redirectToOIDC(w, r)
return
}
}
}
}
// Reject if authentication failed
if !authenticated {
log.WithFields(log.Fields{
"route_id": h.routeID,
"path": r.URL.Path,
"source_ip": extractSourceIP(r),
}).Warn("Authentication failed")
// Call custom reject response or use default
if h.rejectResponse != nil {
h.rejectResponse(w, r)
} else {
// Default: return 401 with WWW-Authenticate header
w.Header().Set("WWW-Authenticate", `Bearer realm="Restricted"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
return
}
log.WithFields(log.Fields{
"route_id": h.routeID,
"auth_method": authMethod,
"user_id": userID,
"path": r.URL.Path,
}).Debug("Authentication successful")
// Store auth info in headers for logging
r.Header.Set("X-Auth-Method", authMethod)
r.Header.Set("X-Auth-User-ID", userID)
// Continue to next handler
h.next.ServeHTTP(w, r)
}
// redirectToOIDC initiates the OAuth/OIDC authentication flow
func (h *authMiddlewareHandler) redirectToOIDC(w http.ResponseWriter, r *http.Request) {
// Generate random state for CSRF protection
state, err := generateRandomString(32)
if err != nil {
log.WithError(err).Error("Failed to generate OIDC state")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Store state with original URL for redirect after auth
// Include the full URL with scheme and host
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
oidcStateStore.Store(state, originalURL, h.routeID)
// Default scopes if not configured
scopes := h.oidcConfig.Scopes
if len(scopes) == 0 {
scopes = []string{"openid", "profile", "email"}
}
// Build authorization URL
authURL, err := url.Parse(h.oidcConfig.ProviderURL)
if err != nil {
log.WithError(err).Error("Invalid OIDC provider URL")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Append /authorize if it doesn't exist (common OIDC endpoint)
if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
}
// Build query parameters
params := url.Values{}
params.Set("client_id", h.oidcConfig.ClientID)
params.Set("redirect_uri", h.oidcConfig.RedirectURL)
params.Set("response_type", "code")
params.Set("scope", strings.Join(scopes, " "))
params.Set("state", state)
// Add audience parameter to get an access token for the API
// This ensures we get a proper JWT for the API audience, not just an ID token
if len(h.oidcConfig.JWTAudience) > 0 && h.oidcConfig.JWTAudience[0] != h.oidcConfig.ClientID {
params.Set("audience", h.oidcConfig.JWTAudience[0])
}
authURL.RawQuery = params.Encode()
log.WithFields(log.Fields{
"route_id": h.routeID,
"provider_url": authURL.String(),
"redirect_url": h.oidcConfig.RedirectURL,
"state": state,
}).Info("Redirecting to OIDC provider for authentication")
// Redirect user to identity provider login page
http.Redirect(w, r, authURL.String(), http.StatusFound)
}
// generateRandomString generates a cryptographically secure random string
func generateRandomString(length int) (string, error) {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
}
// HandleOIDCCallback handles the callback from the OIDC provider
// This should be registered as a route handler for the callback URL
func HandleOIDCCallback(oidcConfig *OIDCConfig) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Get authorization code and state from query parameters
code := r.URL.Query().Get("code")
state := r.URL.Query().Get("state")
if code == "" || state == "" {
log.Error("Missing code or state in OIDC callback")
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
return
}
// Verify state to prevent CSRF
oidcSt, ok := oidcStateStore.Get(state)
if !ok {
log.Error("Invalid or expired OIDC state")
http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
return
}
// Delete state to prevent reuse
oidcStateStore.Delete(state)
// Exchange authorization code for token
token, err := exchangeCodeForToken(code, oidcConfig)
if err != nil {
log.WithError(err).Error("Failed to exchange code for token")
http.Error(w, "Authentication failed", http.StatusUnauthorized)
return
}
// Parse the original URL to add the token as a query parameter
origURL, err := url.Parse(oidcSt.originalURL)
if err != nil {
log.WithError(err).Error("Failed to parse original URL")
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Add token as query parameter so the original domain can set its own cookie
// We use a special parameter name that the auth middleware will look for
q := origURL.Query()
q.Set("_auth_token", token)
origURL.RawQuery = q.Encode()
log.WithFields(log.Fields{
"route_id": oidcSt.routeID,
"original_url": oidcSt.originalURL,
"redirect_url": origURL.String(),
"callback_host": r.Host,
}).Info("OIDC authentication successful, redirecting with token parameter")
// Redirect back to original URL with token parameter
http.Redirect(w, r, origURL.String(), http.StatusFound)
}
}
// exchangeCodeForToken exchanges an authorization code for an access token
func exchangeCodeForToken(code string, config *OIDCConfig) (string, error) {
// Build token endpoint URL
tokenURL, err := url.Parse(config.ProviderURL)
if err != nil {
return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
}
// Auth0 uses /oauth/token, standard OIDC uses /token
// Check if path already contains token endpoint
if !strings.Contains(tokenURL.Path, "/token") {
tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
}
// Build request body
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", config.RedirectURL)
data.Set("client_id", config.ClientID)
// Only include client_secret if it's provided (not needed for public/SPA clients)
if config.ClientSecret != "" {
data.Set("client_secret", config.ClientSecret)
}
// Make token exchange request
resp, err := http.PostForm(tokenURL.String(), data)
if err != nil {
return "", fmt.Errorf("token exchange request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var tokenResp struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
if tokenResp.AccessToken == "" {
return "", fmt.Errorf("no access token in response")
}
// Return the ID token if available (contains user claims), otherwise access token
if tokenResp.IDToken != "" {
return tokenResp.IDToken, nil
}
return tokenResp.AccessToken, nil
}
// getOrInitValidator lazily initializes and returns the JWT validator
func (h *authMiddlewareHandler) getOrInitValidator() *jwt.Validator {
h.validatorMu.Lock()
defer h.validatorMu.Unlock()
if h.jwtValidator == nil {
h.jwtValidator = jwt.NewValidator(
h.oidcConfig.JWTIssuer,
h.oidcConfig.JWTAudience,
h.oidcConfig.JWTKeysLocation,
h.oidcConfig.JWTIdpSignkeyRefreshEnabled,
)
}
return h.jwtValidator
}
// validateJWT validates a JWT token using the handler's JWT validator
func (h *authMiddlewareHandler) validateJWT(tokenString string) bool {
if h.oidcConfig == nil || h.oidcConfig.JWTKeysLocation == "" {
log.Error("JWT validation failed: OIDC config or JWTKeysLocation is missing")
return false
}
// Get or initialize validator
validator := h.getOrInitValidator()
// Validate the token
ctx := context.Background()
parsedToken, err := validator.ValidateAndParse(ctx, tokenString)
if err != nil {
log.WithError(err).Error("JWT validation failed")
// Try to parse token without validation to see what's in it
parts := strings.Split(tokenString, ".")
if len(parts) == 3 {
// Decode payload (middle part)
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err == nil {
log.WithFields(log.Fields{
"payload": string(payload),
}).Debug("Token payload for debugging")
}
}
return false
}
// Token is valid if parsedToken is not nil and Valid is true
return parsedToken != nil && parsedToken.Valid
}
// extractUserIDFromJWT extracts the user ID from a JWT token
func (h *authMiddlewareHandler) extractUserIDFromJWT(tokenString string) string {
if h.jwtValidator == nil {
return ""
}
// Parse the token
ctx := context.Background()
parsedToken, err := h.jwtValidator.ValidateAndParse(ctx, tokenString)
if err != nil {
return ""
}
// parsedToken is already *jwtgo.Token from ValidateAndParse
// Create extractor to get user auth info
extractor := jwt.NewClaimsExtractor()
userAuth, err := extractor.ToUserAuth(parsedToken)
if err != nil {
log.WithError(err).Debug("Failed to extract user ID from JWT")
return ""
}
return userAuth.UserId
}
// wrapWithAuth wraps a handler with the static authentication middleware
// This ALWAYS runs (even when authConfig is nil or empty)
func wrapWithAuth(next http.Handler, authConfig *AuthConfig, routeID string, rejectResponse func(w http.ResponseWriter, r *http.Request), oidcConfig *OIDCConfig) http.Handler {
if authConfig == nil {
authConfig = &AuthConfig{} // Empty config = no auth
}
return &authMiddlewareHandler{
next: next,
authConfig: authConfig,
routeID: routeID,
rejectResponse: rejectResponse,
oidcConfig: oidcConfig,
}
}

View File

@@ -0,0 +1,103 @@
package reverseproxy
import (
"net"
"net/http"
"net/http/httputil"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
)
// Config holds the reverse proxy configuration
type Config struct {
// ListenAddress is the address to listen on for HTTPS (default ":443")
ListenAddress string
// HTTPListenAddress is the address for HTTP (default ":80")
// Used for ACME challenges when HTTPS is enabled, or as main listener when HTTPS is disabled
HTTPListenAddress string
// EnableHTTPS enables automatic HTTPS with Let's Encrypt
EnableHTTPS bool
// TLSEmail is the email for Let's Encrypt registration
TLSEmail string
// CertCacheDir is the directory to cache certificates (default "./certs")
CertCacheDir string
// RequestDataCallback is called for each proxied request with metrics
RequestDataCallback RequestDataCallback
// OIDCConfig is the global OIDC/OAuth configuration for authentication
// This is shared across all routes that use Bearer authentication
// If nil, routes with Bearer auth will fail to initialize
OIDCConfig *oidc.Config
}
// RouteConfig defines a routing configuration
type RouteConfig struct {
// ID is a unique identifier for this route
ID string
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
Domain string
// PathMappings defines paths that should be forwarded to specific ports
// Key is the path prefix (e.g., "/", "/api", "/admin")
// Value is the target IP:port (e.g., "192.168.1.100:3000")
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
PathMappings map[string]string
// Conn is the network connection to use for this route
// This allows routing through specific tunnels (e.g., WireGuard) per route
// This connection will be reused for all requests to this route
Conn net.Conn
// AuthConfig is optional authentication configuration for this route
// Configure ONE of: BasicAuth, PIN, or Bearer (JWT/OIDC)
// If nil, requests pass through without authentication
AuthConfig *auth.Config
// AuthRejectResponse is an optional custom response for authentication failures
// If nil, returns 401 Unauthorized with WWW-Authenticate header
AuthRejectResponse func(w http.ResponseWriter, r *http.Request)
}
// routeEntry represents a compiled route with its proxy
type routeEntry struct {
routeConfig *RouteConfig
path string
target string
proxy *httputil.ReverseProxy
handler http.Handler // handler wraps proxy with middleware (auth, logging, etc.)
}
// RequestDataCallback is called for each proxied request with metrics
type RequestDataCallback func(data RequestData)
// RequestData contains metrics for a proxied request
type RequestData struct {
ServiceID string
Host string
Path string
DurationMs int64
Method string
ResponseCode int32
SourceIP string
AuthMechanism string
UserID string
AuthSuccess bool
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}

View File

@@ -0,0 +1,54 @@
package reverseproxy
import (
"fmt"
"net"
"sync"
"time"
)
// defaultConn is a lazy connection wrapper that uses the standard network dialer
// This is useful for testing or development when not using WireGuard tunnels
type defaultConn struct {
dialer *net.Dialer
mu sync.Mutex
conns map[string]net.Conn // cache connections by "network:address"
}
func (dc *defaultConn) Read(b []byte) (n int, err error) {
return 0, fmt.Errorf("Read not supported on defaultConn - use dial via Transport")
}
func (dc *defaultConn) Write(b []byte) (n int, err error) {
return 0, fmt.Errorf("Write not supported on defaultConn - use dial via Transport")
}
func (dc *defaultConn) Close() error {
dc.mu.Lock()
defer dc.mu.Unlock()
for _, conn := range dc.conns {
conn.Close()
}
dc.conns = make(map[string]net.Conn)
return nil
}
func (dc *defaultConn) LocalAddr() net.Addr { return nil }
func (dc *defaultConn) RemoteAddr() net.Addr { return nil }
func (dc *defaultConn) SetDeadline(t time.Time) error { return nil }
func (dc *defaultConn) SetReadDeadline(t time.Time) error { return nil }
func (dc *defaultConn) SetWriteDeadline(t time.Time) error { return nil }
// NewDefaultConn creates a connection wrapper that uses the standard network dialer
// This is useful for testing or development when not using WireGuard tunnels
// The actual dialing happens when the HTTP Transport calls DialContext
func NewDefaultConn() net.Conn {
return &defaultConn{
dialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
conns: make(map[string]net.Conn),
}
}

View File

@@ -0,0 +1,262 @@
package reverseproxy
import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sort"
"strings"
"time"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth"
)
// buildHandler creates the main HTTP handler with router for static endpoints
func (p *Proxy) buildHandler() http.Handler {
router := mux.NewRouter()
// Register static endpoints
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
// Catch-all handler for dynamic proxy routing
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
return router
}
// handleProxyRequest handles all dynamic proxy requests
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
routeEntry := p.findRoute(r.Host, r.URL.Path)
if routeEntry == nil {
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
http.NotFound(w, r)
return
}
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
routeEntry.handler.ServeHTTP(rw, r)
if p.requestCallback != nil {
duration := time.Since(startTime)
host := r.Host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// Get auth info from headers set by auth middleware
authMechanism := r.Header.Get("X-Auth-Method")
if authMechanism == "" {
authMechanism = "none"
}
userID := r.Header.Get("X-Auth-User-ID")
// Determine auth success based on status code
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
// Extract source IP directly
sourceIP := extractSourceIP(r)
data := RequestData{
ServiceID: routeEntry.routeConfig.ID,
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
Method: r.Method,
ResponseCode: int32(rw.statusCode),
SourceIP: sourceIP,
AuthMechanism: authMechanism,
UserID: userID,
AuthSuccess: authSuccess,
}
p.requestCallback(data)
}
}
// findRoute finds the matching route for a given host and path
func (p *Proxy) findRoute(host, path string) *routeEntry {
p.mu.RLock()
defer p.mu.RUnlock()
// Strip port from host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// O(1) lookup by host
routeConfig, exists := p.routes[host]
if !exists {
return nil
}
// Build list of route entries sorted by path specificity
var entries []*routeEntry
// Create entries for each path mapping
for routePath, target := range routeConfig.PathMappings {
proxy := p.createProxy(routeConfig, target)
// ALWAYS wrap proxy with auth middleware (even if no auth configured)
// This ensures consistent auth handling and logging
handler := auth.Wrap(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.oidcHandler)
// Log auth configuration
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
var authType string
if routeConfig.AuthConfig.BasicAuth != nil {
authType = "basic_auth"
} else if routeConfig.AuthConfig.PIN != nil {
authType = "pin"
} else if routeConfig.AuthConfig.Bearer != nil {
authType = "bearer_jwt"
}
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
"auth_type": authType,
}).Debug("Auth middleware enabled for route")
} else {
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
}).Debug("No authentication configured for route")
}
entries = append(entries, &routeEntry{
routeConfig: routeConfig,
path: routePath,
target: target,
proxy: proxy,
handler: handler,
})
}
// Sort by path specificity (longest first)
sort.Slice(entries, func(i, j int) bool {
pi, pj := entries[i].path, entries[j].path
// Empty string or "/" goes last (catch-all)
if pi == "" || pi == "/" {
return false
}
if pj == "" || pj == "/" {
return true
}
return len(pi) > len(pj)
})
// Find first matching entry
for _, entry := range entries {
if entry.path == "" || entry.path == "/" {
// Catch-all route
return entry
}
if strings.HasPrefix(path, entry.path) {
return entry
}
}
return nil
}
// createProxy creates a reverse proxy for a target with the route's connection
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
// Parse target URL
targetURL, err := url.Parse("http://" + target)
if err != nil {
log.Errorf("Failed to parse target URL %s: %v", target, err)
// Return a proxy that returns 502
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
},
}
}
// Create reverse proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Check if this is a defaultConn (for testing)
if dc, ok := routeConfig.Conn.(*defaultConn); ok {
// For defaultConn, use its dialer directly
proxy.Transport = &http.Transport{
DialContext: dc.dialer.DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
log.Infof("Using default network dialer for route %s (testing mode)", routeConfig.ID)
} else {
// Configure transport to use the provided connection (WireGuard, etc.)
proxy.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
log.Debugf("Using custom connection for route %s to %s", routeConfig.ID, address)
return routeConfig.Conn, nil
},
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
IdleConnTimeout: 0, // Keep alive indefinitely
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
log.Infof("Using custom connection for route %s", routeConfig.ID)
}
// Custom error handler
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
}
return proxy
}
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
// Check if OIDC handler is available
if p.oidcHandler == nil {
log.Error("OIDC callback received but no OIDC handler configured")
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
return
}
// Use the OIDC handler's callback method
handler := p.oidcHandler.HandleCallback()
handler(w, r)
}
// extractSourceIP extracts the source IP from the request
func extractSourceIP(r *http.Request) string {
// Try X-Forwarded-For header first
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Try X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}

View File

@@ -1,19 +0,0 @@
package reverseproxy
// RequestDataCallback is called for each request that passes through the proxy
type RequestDataCallback func(data RequestData)
// RequestData contains metadata about a proxied request
type RequestData struct {
ServiceID string
Host string
Path string
DurationMs int64
Method string
ResponseCode int32
SourceIP string
AuthMechanism string
UserID string
AuthSuccess bool
}

View File

@@ -1,23 +1,13 @@
package reverseproxy
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"sort"
"strings"
"sync"
"time"
"github.com/gorilla/mux"
"golang.org/x/crypto/acme/autocert"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
)
// Proxy wraps a reverse proxy with dynamic routing
@@ -30,90 +20,7 @@ type Proxy struct {
autocertManager *autocert.Manager
isRunning bool
requestCallback RequestDataCallback
}
// Config holds the reverse proxy configuration
type Config struct {
// ListenAddress is the address to listen on for HTTPS (default ":443")
ListenAddress string
// HTTPListenAddress is the address for HTTP (default ":80")
// Used for ACME challenges when HTTPS is enabled, or as main listener when HTTPS is disabled
HTTPListenAddress string
// EnableHTTPS enables automatic HTTPS with Let's Encrypt
EnableHTTPS bool
// TLSEmail is the email for Let's Encrypt registration
TLSEmail string
// CertCacheDir is the directory to cache certificates (default "./certs")
CertCacheDir string
// RequestDataCallback is called for each proxied request with metrics
RequestDataCallback RequestDataCallback
// OIDCConfig is the global OIDC/OAuth configuration for authentication
// This is shared across all routes that use Bearer authentication
// If nil, routes with Bearer auth will fail to initialize
OIDCConfig *OIDCConfig
}
// OIDCConfig holds the global OIDC/OAuth configuration
type OIDCConfig struct {
// OIDC Provider settings
ProviderURL string `env:"NB_OIDC_PROVIDER_URL" json:"provider_url"` // Identity provider URL (e.g., "https://accounts.google.com")
ClientID string `env:"NB_OIDC_CLIENT_ID" json:"client_id"` // OAuth client ID
ClientSecret string `env:"NB_OIDC_CLIENT_SECRET" json:"client_secret"` // OAuth client secret (empty for public clients)
RedirectURL string `env:"NB_OIDC_REDIRECT_URL" json:"redirect_url"` // Redirect URL after auth (e.g., "http://localhost:54321/auth/callback")
Scopes []string `env:"NB_OIDC_SCOPES" json:"scopes"` // Requested scopes (default: ["openid", "profile", "email"])
// JWT Validation settings
JWTKeysLocation string `env:"NB_OIDC_JWT_KEYS_LOCATION" json:"jwt_keys_location"` // JWKS URL for fetching public keys
JWTIssuer string `env:"NB_OIDC_JWT_ISSUER" json:"jwt_issuer"` // Expected issuer claim
JWTAudience []string `env:"NB_OIDC_JWT_AUDIENCE" json:"jwt_audience"` // Expected audience claims
JWTIdpSignkeyRefreshEnabled bool `env:"NB_OIDC_JWT_IDP_SIGNKEY_REFRESH_ENABLED" json:"jwt_idp_signkey_refresh_enabled"` // Enable automatic refresh of signing keys
// Session settings
SessionCookieName string `env:"NB_OIDC_SESSION_COOKIE_NAME" json:"session_cookie_name"` // Cookie name for storing session (default: "auth_session")
}
// RouteConfig defines a routing configuration
type RouteConfig struct {
// ID is a unique identifier for this route
ID string
// Domain is the domain to listen on (e.g., "example.com" or "*" for all)
Domain string
// PathMappings defines paths that should be forwarded to specific ports
// Key is the path prefix (e.g., "/", "/api", "/admin")
// Value is the target IP:port (e.g., "192.168.1.100:3000")
// Must have at least one entry. Use "/" or "" for the default/catch-all route.
PathMappings map[string]string
// Conn is the network connection to use for this route
// This allows routing through specific tunnels (e.g., WireGuard) per route
// This connection will be reused for all requests to this route
Conn net.Conn
// AuthConfig is optional authentication configuration for this route
// Configure ONE of: BasicAuth, PIN, or Bearer (JWT/OIDC)
// If nil, requests pass through without authentication
AuthConfig *AuthConfig
// AuthRejectResponse is an optional custom response for authentication failures
// If nil, returns 401 Unauthorized with WWW-Authenticate header
AuthRejectResponse func(w http.ResponseWriter, r *http.Request)
}
// routeEntry represents a compiled route with its proxy
type routeEntry struct {
routeConfig *RouteConfig
path string
target string
proxy *httputil.ReverseProxy
handler http.Handler // handler wraps proxy with middleware (auth, logging, etc.)
oidcHandler *oidc.Handler
}
// New creates a new reverse proxy
@@ -148,670 +55,17 @@ func New(config Config) (*Proxy, error) {
requestCallback: config.RequestDataCallback,
}
// Initialize OIDC handler if OIDC is configured
// The handler internally creates and manages its own state store
if config.OIDCConfig != nil {
stateStore := oidc.NewStateStore()
p.oidcHandler = oidc.NewHandler(config.OIDCConfig, stateStore)
}
return p, nil
}
// Start starts the reverse proxy server
func (p *Proxy) Start() error {
p.mu.Lock()
if p.isRunning {
p.mu.Unlock()
return fmt.Errorf("reverse proxy already running")
}
p.isRunning = true
p.mu.Unlock()
// Build the main HTTP handler
handler := p.buildHandler()
if p.config.EnableHTTPS {
// Setup autocert manager with dynamic host policy
p.autocertManager = &autocert.Manager{
Cache: autocert.DirCache(p.config.CertCacheDir),
Prompt: autocert.AcceptTOS,
Email: p.config.TLSEmail,
HostPolicy: p.dynamicHostPolicy, // Use dynamic policy based on routes
}
// Start HTTP server for ACME challenges
p.httpServer = &http.Server{
Addr: p.config.HTTPListenAddress,
Handler: p.autocertManager.HTTPHandler(nil),
}
go func() {
log.Infof("Starting HTTP server on %s for ACME challenges", p.config.HTTPListenAddress)
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("HTTP server error: %v", err)
}
}()
// Start HTTPS server
p.server = &http.Server{
Addr: p.config.ListenAddress,
Handler: handler,
TLSConfig: p.autocertManager.TLSConfig(),
}
go func() {
log.Infof("Starting HTTPS server on %s", p.config.ListenAddress)
if err := p.server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
log.Errorf("HTTPS server error: %v", err)
p.mu.Lock()
p.isRunning = false
p.mu.Unlock()
}
}()
} else {
// Start HTTP server only
p.server = &http.Server{
Addr: p.config.HTTPListenAddress,
Handler: handler,
}
go func() {
log.Infof("Starting HTTP server on %s", p.config.HTTPListenAddress)
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("HTTP server error: %v", err)
p.mu.Lock()
p.isRunning = false
p.mu.Unlock()
}
}()
}
log.Infof("Reverse proxy started with %d route(s)", len(p.routes))
return nil
}
// dynamicHostPolicy is a custom host policy that allows certificates for any domain
// that has a configured route
func (p *Proxy) dynamicHostPolicy(ctx context.Context, host string) error {
p.mu.RLock()
defer p.mu.RUnlock()
// Strip port if present
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// O(1) lookup for exact domain match
if _, exists := p.routes[host]; exists {
log.Infof("Allowing certificate for domain: %s", host)
return nil
}
log.Warnf("Rejecting certificate request for unknown domain: %s", host)
return fmt.Errorf("domain %s not configured in routes", host)
}
// Stop gracefully stops the reverse proxy
func (p *Proxy) Stop(ctx context.Context) error {
p.mu.Lock()
if !p.isRunning {
p.mu.Unlock()
return fmt.Errorf("reverse proxy not running")
}
p.mu.Unlock()
log.Info("Stopping reverse proxy...")
// Stop HTTPS server
if p.server != nil {
if err := p.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTPS server: %w", err)
}
}
// Stop HTTP server (ACME challenge server)
if p.httpServer != nil {
if err := p.httpServer.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %w", err)
}
}
p.mu.Lock()
p.isRunning = false
p.mu.Unlock()
log.Info("Reverse proxy stopped")
return nil
}
// buildHandler creates the main HTTP handler with router for static endpoints
func (p *Proxy) buildHandler() http.Handler {
router := mux.NewRouter()
// Register static endpoints
router.HandleFunc("/auth/callback", p.handleOIDCCallback).Methods("GET")
// Catch-all handler for dynamic proxy routing
router.PathPrefix("/").HandlerFunc(p.handleProxyRequest)
return router
}
// handleProxyRequest handles all dynamic proxy requests
func (p *Proxy) handleProxyRequest(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
routeEntry := p.findRoute(r.Host, r.URL.Path)
if routeEntry == nil {
log.Warnf("No route found for host=%s path=%s", r.Host, r.URL.Path)
http.NotFound(w, r)
return
}
rw := &responseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
routeEntry.handler.ServeHTTP(rw, r)
if p.requestCallback != nil {
duration := time.Since(startTime)
host := r.Host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
authMechanism := r.Header.Get("X-Auth-Method")
if authMechanism == "" {
authMechanism = "none"
}
// Determine auth success based on status code
authSuccess := rw.statusCode != http.StatusUnauthorized && rw.statusCode != http.StatusForbidden
// Extract user ID (this would need to be enhanced to extract from tokens/headers)
_, userID, _ := extractAuthInfo(r, rw.statusCode)
data := RequestData{
ServiceID: routeEntry.routeConfig.ID,
Host: host,
Path: r.URL.Path,
DurationMs: duration.Milliseconds(),
Method: r.Method,
ResponseCode: int32(rw.statusCode),
SourceIP: extractSourceIP(r),
AuthMechanism: authMechanism,
UserID: userID,
AuthSuccess: authSuccess,
}
p.requestCallback(data)
}
}
// findRoute finds the matching route for a given host and path
func (p *Proxy) findRoute(host, path string) *routeEntry {
p.mu.RLock()
defer p.mu.RUnlock()
// Strip port from host
if idx := strings.LastIndex(host, ":"); idx != -1 {
host = host[:idx]
}
// O(1) lookup by host
routeConfig, exists := p.routes[host]
if !exists {
return nil
}
// Build list of route entries sorted by path specificity
var entries []*routeEntry
// Create entries for each path mapping
for routePath, target := range routeConfig.PathMappings {
proxy := p.createProxy(routeConfig, target)
// ALWAYS wrap proxy with auth middleware (even if no auth configured)
// This ensures consistent auth handling and logging
handler := wrapWithAuth(proxy, routeConfig.AuthConfig, routeConfig.ID, routeConfig.AuthRejectResponse, p.config.OIDCConfig)
// Log auth configuration
if routeConfig.AuthConfig != nil && !routeConfig.AuthConfig.IsEmpty() {
var authType string
if routeConfig.AuthConfig.BasicAuth != nil {
authType = "basic_auth"
} else if routeConfig.AuthConfig.PIN != nil {
authType = "pin"
} else if routeConfig.AuthConfig.Bearer != nil {
authType = "bearer_jwt"
}
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
"auth_type": authType,
}).Debug("Auth middleware enabled for route")
} else {
log.WithFields(log.Fields{
"route_id": routeConfig.ID,
}).Debug("No authentication configured for route")
}
entries = append(entries, &routeEntry{
routeConfig: routeConfig,
path: routePath,
target: target,
proxy: proxy,
handler: handler,
})
}
// Sort by path specificity (longest first)
sort.Slice(entries, func(i, j int) bool {
pi, pj := entries[i].path, entries[j].path
// Empty string or "/" goes last (catch-all)
if pi == "" || pi == "/" {
return false
}
if pj == "" || pj == "/" {
return true
}
return len(pi) > len(pj)
})
// Find first matching entry
for _, entry := range entries {
if entry.path == "" || entry.path == "/" {
// Catch-all route
return entry
}
if strings.HasPrefix(path, entry.path) {
return entry
}
}
return nil
}
// createProxy creates a reverse proxy for a target with the route's connection
func (p *Proxy) createProxy(routeConfig *RouteConfig, target string) *httputil.ReverseProxy {
// Parse target URL
targetURL, err := url.Parse("http://" + target)
if err != nil {
log.Errorf("Failed to parse target URL %s: %v", target, err)
// Return a proxy that returns 502
return &httputil.ReverseProxy{
Director: func(req *http.Request) {},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, "Bad Gateway", http.StatusBadGateway)
},
}
}
// Create reverse proxy
proxy := httputil.NewSingleHostReverseProxy(targetURL)
// Check if this is a defaultConn (for testing)
if dc, ok := routeConfig.Conn.(*defaultConn); ok {
// For defaultConn, use its dialer directly
proxy.Transport = &http.Transport{
DialContext: dc.dialer.DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
log.Infof("Using default network dialer for route %s (testing mode)", routeConfig.ID)
} else {
// Configure transport to use the provided connection (WireGuard, etc.)
proxy.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
log.Debugf("Using custom connection for route %s to %s", routeConfig.ID, address)
return routeConfig.Conn, nil
},
MaxIdleConns: 1,
MaxIdleConnsPerHost: 1,
IdleConnTimeout: 0, // Keep alive indefinitely
DisableKeepAlives: false,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
log.Infof("Using custom connection for route %s", routeConfig.ID)
}
// Custom error handler
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
log.Errorf("Proxy error for %s%s: %v", r.Host, r.URL.Path, err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
}
return proxy
}
// AddRoute adds a new route configuration
func (p *Proxy) AddRoute(route *RouteConfig) error {
if route == nil {
return fmt.Errorf("route cannot be nil")
}
if route.ID == "" {
return fmt.Errorf("route ID is required")
}
if route.Domain == "" {
return fmt.Errorf("route Domain is required")
}
if len(route.PathMappings) == 0 {
return fmt.Errorf("route must have at least one path mapping")
}
if route.Conn == nil {
return fmt.Errorf("route connection (Conn) is required")
}
p.mu.Lock()
defer p.mu.Unlock()
// Check if route already exists for this domain
if _, exists := p.routes[route.Domain]; exists {
return fmt.Errorf("route for domain %s already exists", route.Domain)
}
// Add route with domain as key
p.routes[route.Domain] = route
log.WithFields(log.Fields{
"route_id": route.ID,
"domain": route.Domain,
"paths": len(route.PathMappings),
}).Info("Added route")
// Note: With this architecture, we don't need to reload the server
// The handler dynamically looks up routes on each request
// Certificates will be obtained automatically when the domain is first accessed
return nil
}
// RemoveRoute removes a route
func (p *Proxy) RemoveRoute(domain string) error {
p.mu.Lock()
defer p.mu.Unlock()
// Check if route exists
if _, exists := p.routes[domain]; !exists {
return fmt.Errorf("route for domain %s not found", domain)
}
// Remove route
delete(p.routes, domain)
log.Infof("Removed route for domain: %s", domain)
return nil
}
// UpdateRoute updates an existing route
func (p *Proxy) UpdateRoute(route *RouteConfig) error {
if route == nil {
return fmt.Errorf("route cannot be nil")
}
if route.ID == "" {
return fmt.Errorf("route ID is required")
}
if route.Domain == "" {
return fmt.Errorf("route Domain is required")
}
p.mu.Lock()
defer p.mu.Unlock()
// Check if route exists for this domain
if _, exists := p.routes[route.Domain]; !exists {
return fmt.Errorf("route for domain %s not found", route.Domain)
}
// Update route using domain as key
p.routes[route.Domain] = route
log.WithFields(log.Fields{
"route_id": route.ID,
"domain": route.Domain,
"paths": len(route.PathMappings),
}).Info("Updated route")
return nil
}
// ListRoutes returns a list of all configured domains
func (p *Proxy) ListRoutes() []string {
p.mu.RLock()
defer p.mu.RUnlock()
domains := make([]string, 0, len(p.routes))
for domain := range p.routes {
domains = append(domains, domain)
}
return domains
}
// GetRoute returns a route configuration by domain
func (p *Proxy) GetRoute(domain string) (*RouteConfig, error) {
p.mu.RLock()
defer p.mu.RUnlock()
route, exists := p.routes[domain]
if !exists {
return nil, fmt.Errorf("route for domain %s not found", domain)
}
return route, nil
}
// IsRunning returns whether the proxy is running
func (p *Proxy) IsRunning() bool {
p.mu.RLock()
defer p.mu.RUnlock()
return p.isRunning
}
// GetConfig returns the proxy configuration
func (p *Proxy) GetConfig() Config {
return p.config
}
// responseWriter wraps http.ResponseWriter to capture status code
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// extractSourceIP extracts the source IP from the request
func extractSourceIP(r *http.Request) string {
// Try X-Forwarded-For header first
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP in the list
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Try X-Real-IP header
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}
// extractAuthInfo extracts authentication information from the request
// Returns: authMechanism, userID, authSuccess
func extractAuthInfo(r *http.Request, statusCode int) (string, string, bool) {
// Check if authentication succeeded based on status code
// 401 = Unauthorized, 403 = Forbidden
authSuccess := statusCode != http.StatusUnauthorized && statusCode != http.StatusForbidden
// Check for Bearer token (JWT, OAuth2, etc.)
if auth := r.Header.Get("Authorization"); auth != "" {
if strings.HasPrefix(auth, "Bearer ") {
// Extract user ID from JWT if possible (you may want to decode the JWT here)
// For now, we'll just indicate it's a bearer token
return "bearer", extractUserIDFromBearer(auth), authSuccess
}
if strings.HasPrefix(auth, "Basic ") {
// Basic authentication
return "basic", extractUserIDFromBasic(auth), authSuccess
}
// Other authorization schemes
return "other", "", authSuccess
}
// Check for API key in headers
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
return "api_key", "", authSuccess
}
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
return "api_key", "", authSuccess
}
// Check for mutual TLS (client certificate)
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
// Extract Common Name from client certificate
cn := r.TLS.PeerCertificates[0].Subject.CommonName
return "mtls", cn, authSuccess
}
// Check for session cookie (common in web apps)
if cookie, err := r.Cookie("session"); err == nil && cookie.Value != "" {
return "session", "", authSuccess
}
// No authentication detected
return "none", "", authSuccess
}
// extractUserIDFromBearer attempts to extract user ID from Bearer token
// Decodes the JWT (without verification) to extract the user ID from standard claims
func extractUserIDFromBearer(auth string) string {
// Remove "Bearer " prefix
tokenString := strings.TrimPrefix(auth, "Bearer ")
if tokenString == "" {
return ""
}
// JWT format: header.payload.signature
// We only need the payload to extract user ID (no verification needed here)
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
log.Debug("Invalid JWT format: expected 3 parts")
return ""
}
// Decode the payload (second part)
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
log.WithError(err).Debug("Failed to decode JWT payload")
return ""
}
// Parse JSON payload
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
log.WithError(err).Debug("Failed to parse JWT claims")
return ""
}
// Try standard user ID claims in order of preference
// 1. "sub" (standard JWT subject claim)
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
// 2. "user_id" (common in some systems)
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
// 3. "email" (fallback)
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
// 4. "preferred_username" (used by some OIDC providers)
if username, ok := claims["preferred_username"].(string); ok && username != "" {
return username
}
return ""
}
// extractUserIDFromBasic extracts username from Basic auth header
func extractUserIDFromBasic(auth string) string {
// Basic auth format: "Basic base64(username:password)"
_ = strings.TrimPrefix(auth, "Basic ")
// Note: We're not decoding it here for security reasons
// The upstream service should handle the actual authentication
// We just note that basic auth was used
return ""
}
// defaultConn is a lazy connection wrapper that uses the standard network dialer
// This is useful for testing or development when not using WireGuard tunnels
type defaultConn struct {
dialer *net.Dialer
mu sync.Mutex
conns map[string]net.Conn // cache connections by "network:address"
}
func (dc *defaultConn) Read(b []byte) (n int, err error) {
return 0, fmt.Errorf("Read not supported on defaultConn - use dial via Transport")
}
func (dc *defaultConn) Write(b []byte) (n int, err error) {
return 0, fmt.Errorf("Write not supported on defaultConn - use dial via Transport")
}
func (dc *defaultConn) Close() error {
dc.mu.Lock()
defer dc.mu.Unlock()
for _, conn := range dc.conns {
conn.Close()
}
dc.conns = make(map[string]net.Conn)
return nil
}
func (dc *defaultConn) LocalAddr() net.Addr { return nil }
func (dc *defaultConn) RemoteAddr() net.Addr { return nil }
func (dc *defaultConn) SetDeadline(t time.Time) error { return nil }
func (dc *defaultConn) SetReadDeadline(t time.Time) error { return nil }
func (dc *defaultConn) SetWriteDeadline(t time.Time) error { return nil }
// NewDefaultConn creates a connection wrapper that uses the standard network dialer
// This is useful for testing or development when not using WireGuard tunnels
// The actual dialing happens when the HTTP Transport calls DialContext
func NewDefaultConn() net.Conn {
return &defaultConn{
dialer: &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
},
conns: make(map[string]net.Conn),
}
}
// handleOIDCCallback handles the global /auth/callback endpoint for all routes
func (p *Proxy) handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
// Check if OIDC is configured globally
if p.config.OIDCConfig == nil {
log.Error("OIDC callback received but no OIDC config found")
http.Error(w, "Authentication not configured", http.StatusInternalServerError)
return
}
// Use the HandleOIDCCallback function from auth.go with global config
handler := HandleOIDCCallback(p.config.OIDCConfig)
handler(w, r)
}

View File

@@ -0,0 +1,123 @@
package reverseproxy
import (
"fmt"
log "github.com/sirupsen/logrus"
)
// AddRoute adds a new route to the proxy
func (p *Proxy) AddRoute(route *RouteConfig) error {
if route == nil {
return fmt.Errorf("route cannot be nil")
}
if route.ID == "" {
return fmt.Errorf("route ID is required")
}
if route.Domain == "" {
return fmt.Errorf("route Domain is required")
}
if len(route.PathMappings) == 0 {
return fmt.Errorf("route must have at least one path mapping")
}
if route.Conn == nil {
return fmt.Errorf("route connection (Conn) is required")
}
p.mu.Lock()
defer p.mu.Unlock()
// Check if route already exists for this domain
if _, exists := p.routes[route.Domain]; exists {
return fmt.Errorf("route for domain %s already exists", route.Domain)
}
// Add route with domain as key
p.routes[route.Domain] = route
log.WithFields(log.Fields{
"route_id": route.ID,
"domain": route.Domain,
"paths": len(route.PathMappings),
}).Info("Added route")
// Note: With this architecture, we don't need to reload the server
// The handler dynamically looks up routes on each request
// Certificates will be obtained automatically when the domain is first accessed
return nil
}
// RemoveRoute removes a route by domain
func (p *Proxy) RemoveRoute(domain string) error {
p.mu.Lock()
defer p.mu.Unlock()
// Check if route exists
if _, exists := p.routes[domain]; !exists {
return fmt.Errorf("route for domain %s not found", domain)
}
// Remove route
delete(p.routes, domain)
log.Infof("Removed route for domain: %s", domain)
return nil
}
// UpdateRoute updates an existing route
func (p *Proxy) UpdateRoute(route *RouteConfig) error {
if route == nil {
return fmt.Errorf("route cannot be nil")
}
if route.ID == "" {
return fmt.Errorf("route ID is required")
}
if route.Domain == "" {
return fmt.Errorf("route Domain is required")
}
p.mu.Lock()
defer p.mu.Unlock()
// Check if route exists for this domain
if _, exists := p.routes[route.Domain]; !exists {
return fmt.Errorf("route for domain %s not found", route.Domain)
}
// Update route using domain as key
p.routes[route.Domain] = route
log.WithFields(log.Fields{
"route_id": route.ID,
"domain": route.Domain,
"paths": len(route.PathMappings),
}).Info("Updated route")
return nil
}
// ListRoutes returns a list of all configured domains
func (p *Proxy) ListRoutes() []string {
p.mu.RLock()
defer p.mu.RUnlock()
domains := make([]string, 0, len(p.routes))
for domain := range p.routes {
domains = append(domains, domain)
}
return domains
}
// GetRoute returns a route configuration by domain
func (p *Proxy) GetRoute(domain string) (*RouteConfig, error) {
p.mu.RLock()
defer p.mu.RUnlock()
route, exists := p.routes[domain]
if !exists {
return nil, fmt.Errorf("route for domain %s not found", domain)
}
return route, nil
}

View File

@@ -0,0 +1,136 @@
package reverseproxy
import (
"context"
"fmt"
"net/http"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/acme/autocert"
)
// Start starts the reverse proxy server
func (p *Proxy) Start() error {
p.mu.Lock()
if p.isRunning {
p.mu.Unlock()
return fmt.Errorf("reverse proxy already running")
}
p.isRunning = true
p.mu.Unlock()
// Build the main HTTP handler
handler := p.buildHandler()
if p.config.EnableHTTPS {
return p.startHTTPS(handler)
}
return p.startHTTP(handler)
}
// startHTTPS starts the proxy with HTTPS and Let's Encrypt
func (p *Proxy) startHTTPS(handler http.Handler) error {
// Setup autocert manager with dynamic host policy
p.autocertManager = &autocert.Manager{
Prompt: autocert.AcceptTOS,
HostPolicy: p.dynamicHostPolicy,
Cache: autocert.DirCache(p.config.CertCacheDir),
Email: p.config.TLSEmail,
RenewBefore: 0, // Use default
}
// Start HTTP server for ACME challenges
p.httpServer = &http.Server{
Addr: p.config.HTTPListenAddress,
Handler: p.autocertManager.HTTPHandler(nil),
}
go func() {
log.Infof("Starting HTTP server for ACME challenges on %s", p.config.HTTPListenAddress)
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Errorf("HTTP server error: %v", err)
}
}()
// Start HTTPS server
p.server = &http.Server{
Addr: p.config.ListenAddress,
Handler: handler,
TLSConfig: p.autocertManager.TLSConfig(),
}
log.Infof("Starting HTTPS reverse proxy server on %s", p.config.ListenAddress)
if err := p.server.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("HTTPS server failed: %w", err)
}
return nil
}
// startHTTP starts the proxy with HTTP only (no TLS)
func (p *Proxy) startHTTP(handler http.Handler) error {
p.server = &http.Server{
Addr: p.config.HTTPListenAddress,
Handler: handler,
}
log.Infof("Starting HTTP reverse proxy server on %s (HTTPS disabled)", p.config.HTTPListenAddress)
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("HTTP server failed: %w", err)
}
return nil
}
// dynamicHostPolicy validates that a requested host has a configured route
func (p *Proxy) dynamicHostPolicy(ctx context.Context, host string) error {
p.mu.RLock()
defer p.mu.RUnlock()
// Check if we have a route for this host
if _, exists := p.routes[host]; exists {
log.Infof("ACME challenge accepted for configured host: %s", host)
return nil
}
log.Warnf("ACME challenge rejected for unconfigured host: %s", host)
return fmt.Errorf("host %s not configured", host)
}
// Stop gracefully stops the reverse proxy server
func (p *Proxy) Stop(ctx context.Context) error {
p.mu.Lock()
if !p.isRunning {
p.mu.Unlock()
return fmt.Errorf("reverse proxy not running")
}
p.isRunning = false
p.mu.Unlock()
log.Info("Stopping reverse proxy server...")
// Stop HTTP server (for ACME challenges)
if p.httpServer != nil {
if err := p.httpServer.Shutdown(ctx); err != nil {
log.Errorf("Error shutting down HTTP server: %v", err)
}
}
// Stop main server
if p.server != nil {
if err := p.server.Shutdown(ctx); err != nil {
return fmt.Errorf("error shutting down server: %w", err)
}
}
log.Info("Reverse proxy server stopped")
return nil
}
// IsRunning returns whether the proxy is running
func (p *Proxy) IsRunning() bool {
p.mu.RLock()
defer p.mu.RUnlock()
return p.isRunning
}

View File

@@ -10,7 +10,7 @@ import (
"github.com/caarlos0/env/v11"
"github.com/netbirdio/netbird/proxy/internal/reverseproxy"
"github.com/netbirdio/netbird/proxy/internal/auth/oidc"
)
var (
@@ -89,7 +89,7 @@ type Config struct {
CertCacheDir string `json:"cert_cache_dir"`
// OIDCConfig is the global OIDC/OAuth configuration for authentication
OIDCConfig *reverseproxy.OIDCConfig `json:"oidc_config,omitempty"`
OIDCConfig *oidc.Config `json:"oidc_config,omitempty"`
}
// ParseAndLoad parses configuration from environment variables

View File

@@ -9,6 +9,8 @@ import (
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/auth/methods"
"github.com/netbirdio/netbird/proxy/internal/reverseproxy"
grpcpkg "github.com/netbirdio/netbird/proxy/pkg/grpc"
pb "github.com/netbirdio/netbird/proxy/pkg/grpc/proto"
@@ -171,8 +173,8 @@ func (s *Server) Start() error {
// Enable Bearer authentication for the test route
// OIDC configuration is set globally in the proxy config above
testAuthConfig := &reverseproxy.AuthConfig{
Bearer: &reverseproxy.BearerConfig{
testAuthConfig := &auth.Config{
Bearer: &methods.BearerConfig{
Enabled: true,
},
}

View File

@@ -24,7 +24,7 @@ type Info struct {
Version string `json:"version"`
Commit string `json:"commit"`
BuildDate string `json:"build_date"`
GoVersion string `json:"go_version"`
GoVersion string `json:"NewSingleHostReverseProxygo_version"`
OS string `json:"os"`
Arch string `json:"arch"`
}