package auth import ( "context" "crypto/rand" "encoding/base64" "fmt" "net/http" "net/url" "strings" "sync" "time" "github.com/coreos/go-oidc/v3/oidc" "golang.org/x/oauth2" ) const stateExpiration = 10 * time.Minute const callbackPath = "/oauth/callback" // OIDCConfig holds configuration for OIDC authentication type OIDCConfig struct { OIDCProviderURL string OIDCClientID string OIDCClientSecret string OIDCScopes []string DistributionGroups []string } // oidcState stores CSRF state with expiration type oidcState struct { OriginalURL string CreatedAt time.Time } // OIDC implements the Scheme interface for JWT/OIDC authentication type OIDC struct { id, accountId, proxyURL string verifier *oidc.IDTokenVerifier oauthConfig *oauth2.Config states map[string]*oidcState statesMux sync.RWMutex distributionGroups []string } // NewOIDC creates a new OIDC authentication scheme func NewOIDC(ctx context.Context, id, accountId, proxyURL string, cfg OIDCConfig) (*OIDC, error) { if cfg.OIDCProviderURL == "" || cfg.OIDCClientID == "" { return nil, fmt.Errorf("OIDC provider URL and client ID are required") } scopes := cfg.OIDCScopes if len(scopes) == 0 { scopes = []string{oidc.ScopeOpenID, "profile", "email"} } provider, err := oidc.NewProvider(ctx, cfg.OIDCProviderURL) if err != nil { return nil, fmt.Errorf("failed to create OIDC provider: %w", err) } o := &OIDC{ id: id, accountId: accountId, proxyURL: proxyURL, verifier: provider.Verifier(&oidc.Config{ ClientID: cfg.OIDCClientID, }), oauthConfig: &oauth2.Config{ ClientID: cfg.OIDCClientID, ClientSecret: cfg.OIDCClientSecret, Endpoint: provider.Endpoint(), Scopes: scopes, }, states: make(map[string]*oidcState), distributionGroups: cfg.DistributionGroups, } go o.cleanupStates() return o, nil } func (*OIDC) Type() Method { return MethodOIDC } func (o *OIDC) Authenticate(r *http.Request) (string, string) { // Try Authorization: Bearer header if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" { return userID, "" } } // Try _auth_token query parameter (from OIDC callback redirect) if token := r.URL.Query().Get("_auth_token"); token != "" { if userID := o.validateToken(r.Context(), token); userID != "" { return userID, "" } } // If the request is not authenticated, return a redirect URL for the UI to // route the user through if they select OIDC login. b := make([]byte, 32) _, _ = rand.Read(b) state := base64.URLEncoding.EncodeToString(b) // TODO: this does not work if you are load balancing across multiple proxy servers. o.statesMux.Lock() o.states[state] = &oidcState{OriginalURL: fmt.Sprintf("https://%s%s", r.Host, r.URL), CreatedAt: time.Now()} o.statesMux.Unlock() return "", (&oauth2.Config{ ClientID: o.oauthConfig.ClientID, ClientSecret: o.oauthConfig.ClientSecret, Endpoint: o.oauthConfig.Endpoint, RedirectURL: o.proxyURL + callbackPath, Scopes: o.oauthConfig.Scopes, }).AuthCodeURL(state) } // Middleware returns an http.Handler that handles OIDC callback and flow initiation. func (o *OIDC) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Handle OIDC callback if r.URL.Path == callbackPath { o.handleCallback(w, r) return } next.ServeHTTP(w, r) }) } // validateToken validates a JWT ID token and returns the user ID (subject) // Returns empty string if token is invalid or user's groups don't appear // in the distributionGroups. func (o *OIDC) validateToken(ctx context.Context, token string) string { if o.verifier == nil { return "" } idToken, err := o.verifier.Verify(ctx, token) if err != nil { // TODO: log or return? return "" } // If distribution groups are configured, check if user has access if len(o.distributionGroups) > 0 { var claims struct { Groups []string `json:"groups"` } if err := idToken.Claims(&claims); err != nil { // TODO: log or return? return "" } allowed := make(map[string]struct{}, len(o.distributionGroups)) for _, g := range o.distributionGroups { allowed[g] = struct{}{} } for _, g := range claims.Groups { if _, ok := allowed[g]; ok { return idToken.Subject } } } // Default deny return "" } // handleCallback processes the OIDC callback func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") state := r.URL.Query().Get("state") if code == "" || state == "" { http.Error(w, "Invalid callback parameters", http.StatusBadRequest) return } // Verify and consume state o.statesMux.Lock() st, ok := o.states[state] if ok { delete(o.states, state) } o.statesMux.Unlock() if !ok { http.Error(w, "Invalid or expired state", http.StatusBadRequest) return } // Exchange code for token token, err := o.oauthConfig.Exchange(r.Context(), code) if err != nil { http.Error(w, "Authentication failed", http.StatusUnauthorized) return } // Prefer ID token if available idToken := token.AccessToken if id, ok := token.Extra("id_token").(string); ok && id != "" { idToken = id } // Redirect back to original URL with token origURL, err := url.Parse(st.OriginalURL) if err != nil { http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } q := origURL.Query() q.Set("_auth_token", idToken) origURL.RawQuery = q.Encode() http.Redirect(w, r, origURL.String(), http.StatusFound) } // cleanupStates periodically removes expired states func (o *OIDC) cleanupStates() { for range time.Tick(time.Minute) { cutoff := time.Now().Add(-stateExpiration) o.statesMux.Lock() for k, v := range o.states { if v.CreatedAt.Before(cutoff) { delete(o.states, k) } } o.statesMux.Unlock() } }