Files
netbird/proxy/internal/auth/oidc.go
2026-02-03 12:10:23 +00:00

240 lines
5.9 KiB
Go

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
)
const (
	// stateExpiration is how long a pending login state stays valid after
	// Authenticate creates it.
	stateExpiration = 10 * time.Minute

	// callbackPath is the path on the proxy that receives the provider's
	// authorization-code redirect.
	callbackPath = "/oauth/callback"
)
// OIDCConfig carries the settings NewOIDC needs to build an OIDC
// authentication scheme.
type OIDCConfig struct {
	// OIDCProviderURL is the provider URL used for discovery. Required.
	OIDCProviderURL string
	// OIDCClientID is the OAuth2 client ID; it is also passed to the ID token
	// verifier. Required.
	OIDCClientID string
	// OIDCClientSecret is the OAuth2 client secret used for the code exchange.
	OIDCClientSecret string
	// OIDCScopes are the scopes requested at login. When empty, NewOIDC falls
	// back to openid, profile and email.
	OIDCScopes []string
	// DistributionGroups lists the groups whose members are granted access; a
	// token must carry at least one of them in its "groups" claim.
	DistributionGroups []string
}
// oidcState is the server-side record kept for one pending login, keyed by its
// CSRF state value.
type oidcState struct {
	OriginalURL string    // where to send the user once the login completes
	CreatedAt   time.Time // when the state was issued; drives expiry
}
// OIDC is the Scheme implementation for OpenID Connect / JWT authentication.
// It accepts ID tokens presented by clients and, for unauthenticated requests,
// hands out a login URL for the provider's authorization-code flow.
type OIDC struct {
	// Identity of this scheme and the public base URL of the proxy; proxyURL
	// is combined with callbackPath to form the OAuth redirect URI.
	id        string
	accountId string
	proxyURL  string

	// verifier checks presented ID tokens; NewOIDC builds it with the client ID.
	verifier *oidc.IDTokenVerifier
	// oauthConfig holds the client credentials, provider endpoint and scopes.
	oauthConfig *oauth2.Config

	// states maps each outstanding CSRF state to its pending login; every
	// access goes through statesMux.
	states    map[string]*oidcState
	statesMux sync.RWMutex

	// distributionGroups are the groups allowed through validateToken.
	distributionGroups []string
}
// NewOIDC creates a new OIDC authentication scheme.
//
// It runs OIDC discovery against cfg.OIDCProviderURL using ctx, so the call
// performs network I/O. cfg.OIDCProviderURL and cfg.OIDCClientID are
// required. When cfg.OIDCScopes is empty, the scopes default to openid,
// profile and email. The slices in cfg are copied, so later changes by the
// caller do not affect the scheme.
//
// The returned scheme starts a background goroutine (cleanupStates) that
// prunes expired login states. That goroutine is never stopped.
func NewOIDC(ctx context.Context, id, accountId, proxyURL string, cfg OIDCConfig) (*OIDC, error) {
	if cfg.OIDCProviderURL == "" || cfg.OIDCClientID == "" {
		return nil, fmt.Errorf("OIDC provider URL and client ID are required")
	}

	// Copy caller-owned slices so the scheme's configuration cannot be
	// mutated from outside after construction.
	scopes := append([]string(nil), cfg.OIDCScopes...)
	if len(scopes) == 0 {
		scopes = []string{oidc.ScopeOpenID, "profile", "email"}
	}

	provider, err := oidc.NewProvider(ctx, cfg.OIDCProviderURL)
	if err != nil {
		return nil, fmt.Errorf("failed to create OIDC provider: %w", err)
	}

	o := &OIDC{
		id:        id,
		accountId: accountId,
		proxyURL:  proxyURL,
		verifier: provider.Verifier(&oidc.Config{
			ClientID: cfg.OIDCClientID,
		}),
		oauthConfig: &oauth2.Config{
			ClientID:     cfg.OIDCClientID,
			ClientSecret: cfg.OIDCClientSecret,
			Endpoint:     provider.Endpoint(),
			// The token request must carry the same redirect_uri as the
			// authorization request (RFC 6749 §4.1.3). Authenticate uses
			// proxyURL + callbackPath, so the config used by the code exchange
			// must use exactly that value too; do not normalize it here.
			RedirectURL: proxyURL + callbackPath,
			Scopes:      scopes,
		},
		states:             make(map[string]*oidcState),
		distributionGroups: append([]string(nil), cfg.DistributionGroups...),
	}

	go o.cleanupStates()
	return o, nil
}
// Type reports which authentication method this scheme implements.
func (*OIDC) Type() Method { return MethodOIDC }
// Authenticate implements the Scheme interface.
//
// It returns (userID, "") when the request carries an ID token that
// validateToken accepts. The token may arrive either in an
// "Authorization: Bearer <token>" header or in the _auth_token query
// parameter that handleCallback appends after a login.
//
// Otherwise it returns ("", loginURL). loginURL is the provider's
// authorization URL carrying a fresh random state. The URL the user originally
// requested is stored under that state, so handleCallback can send the user
// back to it.
//
// If no random state can be generated, it returns ("", ""); it never falls
// back to a predictable state.
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
	// Try Authorization: Bearer <token> header.
	if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
		if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" {
			return userID, ""
		}
	}

	// Try _auth_token query parameter (set by the OIDC callback redirect).
	if token := r.URL.Query().Get("_auth_token"); token != "" {
		if userID := o.validateToken(r.Context(), token); userID != "" {
			return userID, ""
		}
	}

	// Not authenticated: prepare a login redirect the UI can route the user
	// through if they pick OIDC login.
	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		// The state is the CSRF token for the login flow. Without entropy it
		// would be guessable, so offer no login URL rather than a weak one.
		return "", ""
	}
	state := base64.URLEncoding.EncodeToString(b)

	// RequestURI yields "path?query" even when r.URL is in absolute form, and
	// "/" for an empty path. Formatting r.URL with %s would embed a scheme
	// and host in those cases.
	// NOTE(review): r.Host comes from the client. This assumes the proxy only
	// routes requests for this service's host here — confirm.
	originalURL := "https://" + r.Host + r.URL.RequestURI()

	// TODO: this does not work if you are load balancing across multiple proxy servers.
	o.statesMux.Lock()
	o.states[state] = &oidcState{OriginalURL: originalURL, CreatedAt: time.Now()}
	o.statesMux.Unlock()

	// Work on a copy so the shared config is never mutated concurrently.
	loginCfg := *o.oauthConfig
	loginCfg.RedirectURL = o.proxyURL + callbackPath
	return "", loginCfg.AuthCodeURL(state)
}
// Middleware wraps next so that the scheme itself serves the OIDC callback
// path. Every other request is forwarded to next unchanged.
func (o *OIDC) Middleware(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != callbackPath {
			next.ServeHTTP(w, r)
			return
		}
		o.handleCallback(w, r)
	}
	return http.HandlerFunc(serve)
}
// validateToken verifies token as an ID token and returns its subject.
//
// Access is deny-by-default. The subject is returned only when the token
// verifies AND its "groups" claim shares at least one entry with
// distributionGroups. In every other case the empty string is returned:
//   - no verifier is configured;
//   - verification fails;
//   - no distribution groups are configured;
//   - the claims cannot be decoded;
//   - there is no group overlap.
func (o *OIDC) validateToken(ctx context.Context, token string) string {
	if o.verifier == nil {
		return ""
	}

	idToken, err := o.verifier.Verify(ctx, token)
	if err != nil {
		// TODO: decide whether verification failures should be logged or surfaced.
		return ""
	}

	// With no configured groups nobody can match, so deny straight away.
	if len(o.distributionGroups) == 0 {
		return ""
	}

	var claims struct {
		Groups []string `json:"groups"`
	}
	if err := idToken.Claims(&claims); err != nil {
		// TODO: decide whether claim decoding failures should be logged or surfaced.
		return ""
	}

	for _, permitted := range o.distributionGroups {
		for _, held := range claims.Groups {
			if held == permitted {
				return idToken.Subject
			}
		}
	}
	return ""
}
// handleCallback completes the OIDC authorization-code flow at callbackPath.
//
// It proceeds in these steps:
//  1. Require the "code" and "state" query parameters.
//  2. Consume the matching pending state; each state is single-use.
//  3. Reject states older than stateExpiration.
//  4. Exchange the code for tokens.
//  5. Redirect the browser to the URL recorded when the flow started, with
//     the token appended as the _auth_token query parameter. The ID token is
//     preferred; the access token is used when the provider returned no ID
//     token.
//
// Failures are answered with:
//   - 400 for bad parameters or an unknown/expired state;
//   - 401 when the code exchange fails;
//   - 500 when the stored URL cannot be parsed.
func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	code := query.Get("code")
	state := query.Get("state")
	if code == "" || state == "" {
		http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
		return
	}

	// Look up and delete in one critical section so a state can be redeemed
	// at most once.
	o.statesMux.Lock()
	st, ok := o.states[state]
	if ok {
		delete(o.states, state)
	}
	o.statesMux.Unlock()

	// cleanupStates only prunes periodically, so an expired entry may still
	// be present; enforce the expiry at the point of use as well.
	if !ok || time.Since(st.CreatedAt) > stateExpiration {
		http.Error(w, "Invalid or expired state", http.StatusBadRequest)
		return
	}

	// The token request must repeat the redirect_uri sent in the authorization
	// request (RFC 6749 §4.1.3). Authenticate sends o.proxyURL + callbackPath,
	// so use a copy of the config carrying exactly that value.
	exchangeCfg := *o.oauthConfig
	exchangeCfg.RedirectURL = o.proxyURL + callbackPath
	token, err := exchangeCfg.Exchange(r.Context(), code)
	if err != nil {
		http.Error(w, "Authentication failed", http.StatusUnauthorized)
		return
	}

	// Prefer the ID token if the provider returned one.
	idToken := token.AccessToken
	if id, ok := token.Extra("id_token").(string); ok && id != "" {
		idToken = id
	}

	// Redirect back to the original URL with the token attached.
	origURL, err := url.Parse(st.OriginalURL)
	if err != nil {
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	q := origURL.Query()
	q.Set("_auth_token", idToken)
	origURL.RawQuery = q.Encode()
	http.Redirect(w, r, origURL.String(), http.StatusFound)
}
// cleanupStates deletes pending login states issued more than stateExpiration
// ago. It wakes once a minute and never returns, so the goroutine running it
// lives for the rest of the process.
func (o *OIDC) cleanupStates() {
	ticker := time.NewTicker(time.Minute)
	for range ticker.C {
		oldestAllowed := time.Now().Add(-stateExpiration)

		o.statesMux.Lock()
		for key, pending := range o.states {
			if !pending.CreatedAt.Before(oldestAllowed) {
				continue
			}
			// Deleting the current key during range is permitted in Go.
			delete(o.states, key)
		}
		o.statesMux.Unlock()
	}
}