// Source: netbird/proxy/internal/auth/oidc/handler.go
// Snapshot: 2026-01-16 12:01:52 +01:00 (286 lines, 8.2 KiB, Go)

package oidc
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/shared/auth/jwt"
)
// Handler manages the OIDC authorization-code flow: it redirects
// unauthenticated requests to the provider (RedirectToProvider), serves the
// provider callback (HandleCallback), and validates the resulting JWTs
// (ValidateJWT, ExtractUserID).
type Handler struct {
	// config holds the provider, client, JWT and session-cookie settings.
	config *Config
	// stateStore keeps pending authorization requests (original URL and route
	// ID) keyed by the random state value sent to the provider.
	stateStore *StateStore
	// jwtValidator verifies tokens; nil when config.JWTKeysLocation is empty.
	jwtValidator *jwt.Validator
}
// NewHandler returns a Handler that uses config for provider settings and
// stateStore for pending authorization requests.
//
// A JWT validator is created only when config.JWTKeysLocation is non-empty. It
// is built from JWTIssuer, JWTAudience, JWTKeysLocation and
// JWTIdpSignkeyRefreshEnabled. Without a validator the handler cannot verify
// tokens.
func NewHandler(config *Config, stateStore *StateStore) *Handler {
	h := &Handler{
		config:     config,
		stateStore: stateStore,
	}
	if config.JWTKeysLocation == "" {
		// No key source configured: leave jwtValidator nil.
		return h
	}
	h.jwtValidator = jwt.NewValidator(
		config.JWTIssuer,
		config.JWTAudience,
		config.JWTKeysLocation,
		config.JWTIdpSignkeyRefreshEnabled,
	)
	return h
}
// RedirectToProvider starts the OAuth 2.0 / OIDC authorization-code flow for
// routeID by redirecting the browser (302) to the provider's authorization
// endpoint.
//
// The endpoint is derived from Config.ProviderURL. "/authorize" is appended
// unless the path already ends in "/authorize" or "/auth". Query parameters
// already present on ProviderURL are preserved.
//
// A 32-character random state value is generated for CSRF protection. It is
// stored with the absolute URL of the current request (scheme, host, path and
// query) and routeID, so the callback can send the user back afterwards.
//
// Scopes default to "openid profile email" when none are configured. An
// "audience" parameter is added when the first configured JWT audience
// differs from the client ID.
//
// On an invalid provider URL or a state-generation failure, a 500 response is
// written and nothing is stored.
func (h *Handler) RedirectToProvider(w http.ResponseWriter, r *http.Request, routeID string) {
	// Validate the provider URL before creating any state, so a misconfigured
	// provider does not leave unused entries in the state store.
	authURL, err := url.Parse(h.config.ProviderURL)
	if err != nil {
		log.WithError(err).Error("Invalid OIDC provider URL")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Append /authorize if it doesn't exist (common OIDC endpoint).
	if !strings.HasSuffix(authURL.Path, "/authorize") && !strings.HasSuffix(authURL.Path, "/auth") {
		authURL.Path = strings.TrimSuffix(authURL.Path, "/") + "/authorize"
	}
	// Capture the endpoint before the state is attached, so the log line below
	// does not leak the CSRF state value.
	endpoint := authURL.String()

	state, err := generateRandomString(32)
	if err != nil {
		log.WithError(err).Error("Failed to generate OIDC state")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Remember the full original URL (scheme + host + path/query) for the
	// redirect after authentication.
	scheme := "http"
	if r.TLS != nil {
		scheme = "https"
	}
	originalURL := fmt.Sprintf("%s://%s%s", scheme, r.Host, r.URL.String())
	h.stateStore.Store(state, originalURL, routeID)

	scopes := h.config.Scopes
	if len(scopes) == 0 {
		scopes = []string{"openid", "profile", "email"}
	}

	// Start from the provider URL's own query so provider-specific parameters
	// are kept; the flow parameters below override any same-named ones.
	params := authURL.Query()
	params.Set("client_id", h.config.ClientID)
	params.Set("redirect_uri", h.config.RedirectURL)
	params.Set("response_type", "code")
	params.Set("scope", strings.Join(scopes, " "))
	params.Set("state", state)
	if len(h.config.JWTAudience) > 0 && h.config.JWTAudience[0] != h.config.ClientID {
		params.Set("audience", h.config.JWTAudience[0])
	}
	authURL.RawQuery = params.Encode()

	log.WithFields(log.Fields{
		"route_id":     routeID,
		"provider_url": endpoint,
		"redirect_url": h.config.RedirectURL,
	}).Info("Redirecting to OIDC provider for authentication")

	http.Redirect(w, r, authURL.String(), http.StatusFound)
}
// HandleCallback returns the http.HandlerFunc serving the OIDC redirect URI.
//
// The handler reads the "code" and "state" query parameters sent by the
// provider. The state must match an entry created by RedirectToProvider. It is
// deleted before the code is redeemed so it cannot be reused.
//
// The code is exchanged via exchangeCodeForToken. The resulting token is added
// as the "_auth_token" query parameter to the stored original URL, and the
// browser is redirected there (302).
//
// Error responses:
//   - 400 when code or state is missing (including when the provider reports
//     an error instead of a code) or the state is unknown or expired.
//   - 401 when the code exchange fails.
//   - 500 when the stored original URL cannot be parsed.
func (h *Handler) HandleCallback() http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()
		code := query.Get("code")
		state := query.Get("state")
		if code == "" || state == "" {
			// A failed authorization comes back with "error" (and optionally
			// "error_description") instead of a code. Log it so denied or
			// misconfigured logins can be told apart from malformed requests.
			if providerErr := query.Get("error"); providerErr != "" {
				log.WithFields(log.Fields{
					"provider_error":    providerErr,
					"error_description": query.Get("error_description"),
				}).Error("OIDC provider returned an error in callback")
			} else {
				log.Error("Missing code or state in OIDC callback")
			}
			http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
			return
		}

		// Verify state to prevent CSRF.
		oidcSt, ok := h.stateStore.Get(state)
		if !ok {
			log.Error("Invalid or expired OIDC state")
			http.Error(w, "Invalid or expired state parameter", http.StatusBadRequest)
			return
		}
		// State is single-use. NOTE(review): Get and Delete are separate calls,
		// so two concurrent callbacks carrying the same state could both pass the
		// check above. An atomic get-and-delete on StateStore would close that
		// window.
		h.stateStore.Delete(state)

		token, err := h.exchangeCodeForToken(code)
		if err != nil {
			log.WithError(err).Error("Failed to exchange code for token")
			http.Error(w, "Authentication failed", http.StatusUnauthorized)
			return
		}

		origURL, err := url.Parse(oidcSt.OriginalURL)
		if err != nil {
			log.WithError(err).Error("Failed to parse original URL")
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}

		// Pass the token as a query parameter so the original domain can set its
		// own cookie; the auth middleware looks for "_auth_token".
		// NOTE(review): this places the token in the browser URL/history.
		q := origURL.Query()
		q.Set("_auth_token", token)
		origURL.RawQuery = q.Encode()

		// The final redirect URL is deliberately not logged: it contains the token.
		log.WithFields(log.Fields{
			"route_id":      oidcSt.RouteID,
			"original_url":  oidcSt.OriginalURL,
			"callback_host": r.Host,
		}).Info("OIDC authentication successful, redirecting with token parameter")

		http.Redirect(w, r, origURL.String(), http.StatusFound)
	}
}
// tokenExchangeTimeout bounds one authorization-code exchange with the
// provider's token endpoint. http.DefaultClient has no timeout, so without
// this an unresponsive provider could stall the OIDC callback indefinitely.
const tokenExchangeTimeout = 10 * time.Second

// maxTokenErrorBodyBytes caps how much of a non-200 token response body is
// read into the returned error.
const maxTokenErrorBodyBytes = 4 << 10

// tokenExchangeClient performs all token endpoint requests. Its Transport is
// nil, so it shares http.DefaultTransport's connection pool.
var tokenExchangeClient = &http.Client{Timeout: tokenExchangeTimeout}

// exchangeCodeForToken redeems an authorization code at the provider's token
// endpoint and returns the token to hand back to the original domain.
//
// The token endpoint is derived from Config.ProviderURL. If its path does not
// already contain "/token", "/oauth/token" is appended (the Auth0 layout).
// Providers serving the endpoint elsewhere need ProviderURL to include the
// token path.
//
// The request is a form-encoded POST carrying grant_type, code, redirect_uri,
// client_id and, when configured, client_secret. It is bounded by
// tokenExchangeTimeout.
//
// It returns the ID token when the response contains one (it carries the user
// claims), otherwise the access token. An error is returned when:
//   - the request fails or times out,
//   - the status is not 200 OK (up to maxTokenErrorBodyBytes of the body is
//     included),
//   - the body is not valid JSON, or
//   - access_token is missing.
func (h *Handler) exchangeCodeForToken(code string) (string, error) {
	tokenURL, err := url.Parse(h.config.ProviderURL)
	if err != nil {
		return "", fmt.Errorf("invalid OIDC provider URL: %w", err)
	}
	if !strings.Contains(tokenURL.Path, "/token") {
		tokenURL.Path = strings.TrimSuffix(tokenURL.Path, "/") + "/oauth/token"
	}

	data := url.Values{}
	data.Set("grant_type", "authorization_code")
	data.Set("code", code)
	data.Set("redirect_uri", h.config.RedirectURL)
	data.Set("client_id", h.config.ClientID)
	// Public/SPA clients have no secret; only send it when configured.
	if h.config.ClientSecret != "" {
		data.Set("client_secret", h.config.ClientSecret)
	}

	req, err := http.NewRequest(http.MethodPost, tokenURL.String(), strings.NewReader(data.Encode()))
	if err != nil {
		return "", fmt.Errorf("token exchange request failed: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	// Ask for JSON explicitly: it is the only response format decoded below.
	req.Header.Set("Accept", "application/json")

	resp, err := tokenExchangeClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("token exchange request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error, so a read failure is
		// ignored. The limit keeps a misbehaving endpoint from bloating logs/memory.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxTokenErrorBodyBytes))
		return "", fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int    `json:"expires_in"`
		IDToken     string `json:"id_token"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		return "", fmt.Errorf("failed to decode token response: %w", err)
	}
	if tokenResp.AccessToken == "" {
		return "", fmt.Errorf("no access token in response")
	}

	// Prefer the ID token (contains user claims); fall back to the access token.
	if tokenResp.IDToken != "" {
		return tokenResp.IDToken, nil
	}
	return tokenResp.AccessToken, nil
}
// ValidateJWT reports whether tokenString is accepted by the configured JWT
// validator.
//
// It returns false in three cases:
//   - no validator is configured (Config.JWTKeysLocation was empty),
//   - ValidateAndParse returns an error, or
//   - the parsed token is nil or not marked Valid.
//
// Failures are logged at error level. When debug logging is enabled, the
// unverified payload of a rejected token is also logged.
func (h *Handler) ValidateJWT(tokenString string) bool {
	if h.jwtValidator == nil {
		log.Error("JWT validation failed: JWT validator not initialized")
		return false
	}

	parsedToken, err := h.jwtValidator.ValidateAndParse(context.Background(), tokenString)
	if err != nil {
		log.WithError(err).Error("JWT validation failed")
		logUnverifiedJWTPayload(tokenString)
		return false
	}

	return parsedToken != nil && parsedToken.Valid
}

// logUnverifiedJWTPayload logs the base64url-decoded middle segment of a
// three-part JWT at debug level, without verifying it.
//
// The split and decode are skipped entirely unless debug logging is enabled.
// The payload is only useful for troubleshooting, and it comes from an
// unverified (possibly hostile) token.
func logUnverifiedJWTPayload(tokenString string) {
	if !log.IsLevelEnabled(log.DebugLevel) {
		return
	}
	parts := strings.Split(tokenString, ".")
	if len(parts) != 3 {
		return
	}
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return
	}
	log.WithFields(log.Fields{
		"payload": string(payload),
	}).Debug("Token payload for debugging")
}
// ExtractUserID validates tokenString and returns the user ID taken from its
// claims by jwt.ClaimsExtractor.
//
// It returns "" in four cases:
//   - no JWT validator is configured,
//   - validation fails,
//   - the parsed token is nil or not marked Valid (the same acceptance rule
//     ValidateJWT applies), or
//   - the claims cannot be converted to user auth info.
func (h *Handler) ExtractUserID(tokenString string) string {
	if h.jwtValidator == nil {
		return ""
	}

	parsedToken, err := h.jwtValidator.ValidateAndParse(context.Background(), tokenString)
	// Never hand a nil or unaccepted token to the claims extractor.
	if err != nil || parsedToken == nil || !parsedToken.Valid {
		return ""
	}

	userAuth, err := jwt.NewClaimsExtractor().ToUserAuth(parsedToken)
	if err != nil {
		log.WithError(err).Debug("Failed to extract user ID from JWT")
		return ""
	}
	return userAuth.UserId
}
// SessionCookieName returns the name of the session cookie:
// Config.SessionCookieName when set, otherwise "auth_session".
func (h *Handler) SessionCookieName() string {
	name := h.config.SessionCookieName
	if name == "" {
		name = "auth_session"
	}
	return name
}