mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
management OIDC implementation using pkce
This commit is contained in:
@@ -50,8 +50,3 @@ func (l Link) Authenticate(r *http.Request) (string, string) {
|
||||
|
||||
return "", linkFormId
|
||||
}
|
||||
|
||||
func (l Link) Middleware(next http.Handler) http.Handler {
|
||||
// TODO: handle magic link redirects, should be similar to OIDC.
|
||||
return next
|
||||
}
|
||||
|
||||
@@ -56,12 +56,6 @@ type Scheme interface {
|
||||
// be included in a UI template when prompting the user to authenticate.
|
||||
// If the request is authenticated, then a user id should be returned.
|
||||
Authenticate(*http.Request) (userid string, promptData string)
|
||||
// Middleware is applied within the outer auth middleware, but they will
|
||||
// be applied after authentication if no scheme has authenticated a
|
||||
// request.
|
||||
// If no scheme Middleware blocks the request processing, then the auth
|
||||
// middleware will then present the user with the auth UI.
|
||||
Middleware(http.Handler) http.Handler
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
@@ -137,26 +131,13 @@ func (mw *Middleware) Protect(next http.Handler) http.Handler {
|
||||
methods[s.Type().String()] = promptData
|
||||
}
|
||||
|
||||
// The handler is passed through the scheme middlewares,
|
||||
// if none of them intercept the request, then this handler will
|
||||
// be called and present the user with the authentication page.
|
||||
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := tmpl.Execute(w, struct {
|
||||
Methods map[string]string
|
||||
}{
|
||||
Methods: methods,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
}))
|
||||
|
||||
// No authentication succeeded. Apply the scheme handlers.
|
||||
for _, s := range schemes {
|
||||
handler = s.Middleware(handler)
|
||||
if err := tmpl.Execute(w, struct {
|
||||
Methods map[string]string
|
||||
}{
|
||||
Methods: methods,
|
||||
}); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
}
|
||||
|
||||
// Run the unauthenticated request against the scheme handlers and the final UI handler.
|
||||
handler.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2,30 +2,31 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"golang.org/x/oauth2"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
const stateExpiration = 10 * time.Minute
|
||||
|
||||
const callbackPath = "/oauth/callback"
|
||||
type urlGenerator interface {
|
||||
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
|
||||
}
|
||||
|
||||
// OIDCConfig holds configuration for OIDC authentication
|
||||
// OIDCConfig holds configuration for OIDC JWT verification
|
||||
type OIDCConfig struct {
|
||||
OIDCProviderURL string
|
||||
OIDCClientID string
|
||||
OIDCClientSecret string
|
||||
OIDCScopes []string
|
||||
DistributionGroups []string
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
// oidcState stores CSRF state with expiration
|
||||
@@ -36,50 +37,33 @@ type oidcState struct {
|
||||
|
||||
// OIDC implements the Scheme interface for JWT/OIDC authentication
|
||||
type OIDC struct {
|
||||
id, accountId, proxyURL string
|
||||
verifier *oidc.IDTokenVerifier
|
||||
oauthConfig *oauth2.Config
|
||||
states map[string]*oidcState
|
||||
statesMux sync.RWMutex
|
||||
distributionGroups []string
|
||||
id, accountId string
|
||||
validator *jwt.Validator
|
||||
maxTokenAgeSeconds int64
|
||||
client urlGenerator
|
||||
states map[string]*oidcState
|
||||
statesMux sync.RWMutex
|
||||
}
|
||||
|
||||
// NewOIDC creates a new OIDC authentication scheme
|
||||
func NewOIDC(ctx context.Context, id, accountId, proxyURL string, cfg OIDCConfig) (*OIDC, error) {
|
||||
if cfg.OIDCProviderURL == "" || cfg.OIDCClientID == "" {
|
||||
return nil, fmt.Errorf("OIDC provider URL and client ID are required")
|
||||
}
|
||||
|
||||
scopes := cfg.OIDCScopes
|
||||
if len(scopes) == 0 {
|
||||
scopes = []string{oidc.ScopeOpenID, "profile", "email"}
|
||||
}
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, cfg.OIDCProviderURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create OIDC provider: %w", err)
|
||||
}
|
||||
|
||||
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
|
||||
o := &OIDC{
|
||||
id: id,
|
||||
accountId: accountId,
|
||||
proxyURL: proxyURL,
|
||||
verifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
}),
|
||||
oauthConfig: &oauth2.Config{
|
||||
ClientID: cfg.OIDCClientID,
|
||||
ClientSecret: cfg.OIDCClientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
Scopes: scopes,
|
||||
},
|
||||
validator: jwt.NewValidator(
|
||||
cfg.Issuer,
|
||||
cfg.Audiences,
|
||||
cfg.KeysLocation,
|
||||
true,
|
||||
),
|
||||
maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
|
||||
client: client,
|
||||
states: make(map[string]*oidcState),
|
||||
distributionGroups: cfg.DistributionGroups,
|
||||
}
|
||||
|
||||
go o.cleanupStates()
|
||||
|
||||
return o, nil
|
||||
return o
|
||||
}
|
||||
|
||||
func (*OIDC) Type() Method {
|
||||
@@ -94,134 +78,74 @@ func (o *OIDC) Authenticate(r *http.Request) (string, string) {
|
||||
}
|
||||
}
|
||||
|
||||
// Try _auth_token query parameter (from OIDC callback redirect)
|
||||
if token := r.URL.Query().Get("_auth_token"); token != "" {
|
||||
if userID := o.validateToken(r.Context(), token); userID != "" {
|
||||
return userID, ""
|
||||
}
|
||||
redirectURL := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
}
|
||||
|
||||
// If the request is not authenticated, return a redirect URL for the UI to
|
||||
// route the user through if they select OIDC login.
|
||||
b := make([]byte, 32)
|
||||
_, _ = rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
// TODO: this does not work if you are load balancing across multiple proxy servers.
|
||||
o.statesMux.Lock()
|
||||
o.states[state] = &oidcState{OriginalURL: fmt.Sprintf("https://%s%s", r.Host, r.URL), CreatedAt: time.Now()}
|
||||
o.statesMux.Unlock()
|
||||
|
||||
return "", (&oauth2.Config{
|
||||
ClientID: o.oauthConfig.ClientID,
|
||||
ClientSecret: o.oauthConfig.ClientSecret,
|
||||
Endpoint: o.oauthConfig.Endpoint,
|
||||
RedirectURL: o.proxyURL + callbackPath,
|
||||
Scopes: o.oauthConfig.Scopes,
|
||||
}).AuthCodeURL(state)
|
||||
}
|
||||
|
||||
// Middleware returns an http.Handler that handles OIDC callback and flow initiation.
|
||||
func (o *OIDC) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Handle OIDC callback
|
||||
if r.URL.Path == callbackPath {
|
||||
o.handleCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{
|
||||
Id: o.id,
|
||||
AccountId: o.accountId,
|
||||
RedirectUrl: redirectURL.String(),
|
||||
})
|
||||
if err != nil {
|
||||
// TODO: log
|
||||
return "", ""
|
||||
}
|
||||
|
||||
return "", res.GetUrl()
|
||||
}
|
||||
|
||||
// validateToken validates a JWT ID token and returns the user ID (subject)
|
||||
// Returns empty string if token is invalid or user's groups don't appear
|
||||
// in the distributionGroups.
|
||||
// Returns empty string if token is invalid.
|
||||
func (o *OIDC) validateToken(ctx context.Context, token string) string {
|
||||
if o.verifier == nil {
|
||||
if o.validator == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
idToken, err := o.verifier.Verify(ctx, token)
|
||||
idToken, err := o.validator.ValidateAndParse(ctx, token)
|
||||
if err != nil {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
// If distribution groups are configured, check if user has access
|
||||
if len(o.distributionGroups) > 0 {
|
||||
var claims struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
allowed := make(map[string]struct{}, len(o.distributionGroups))
|
||||
for _, g := range o.distributionGroups {
|
||||
allowed[g] = struct{}{}
|
||||
}
|
||||
|
||||
for _, g := range claims.Groups {
|
||||
if _, ok := allowed[g]; ok {
|
||||
return idToken.Subject
|
||||
}
|
||||
}
|
||||
iat, err := idToken.Claims.GetIssuedAt()
|
||||
if err != nil {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
// Default deny
|
||||
return ""
|
||||
if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
|
||||
// TODO: log or return?
|
||||
return ""
|
||||
}
|
||||
|
||||
return extractUserID(idToken)
|
||||
}
|
||||
|
||||
// handleCallback processes the OIDC callback
|
||||
func (o *OIDC) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" || state == "" {
|
||||
http.Error(w, "Invalid callback parameters", http.StatusBadRequest)
|
||||
return
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Verify and consume state
|
||||
o.statesMux.Lock()
|
||||
st, ok := o.states[state]
|
||||
if ok {
|
||||
delete(o.states, state)
|
||||
}
|
||||
o.statesMux.Unlock()
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
http.Error(w, "Invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := o.oauthConfig.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
http.Error(w, "Authentication failed", http.StatusUnauthorized)
|
||||
return
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
|
||||
// Prefer ID token if available
|
||||
idToken := token.AccessToken
|
||||
if id, ok := token.Extra("id_token").(string); ok && id != "" {
|
||||
idToken = id
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
|
||||
// Redirect back to original URL with token
|
||||
origURL, err := url.Parse(st.OriginalURL)
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
|
||||
q := origURL.Query()
|
||||
q.Set("_auth_token", idToken)
|
||||
origURL.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, origURL.String(), http.StatusFound)
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// cleanupStates periodically removes expired states
|
||||
|
||||
@@ -36,6 +36,11 @@ func (Password) Type() Method {
|
||||
func (p Password) Authenticate(r *http.Request) (string, string) {
|
||||
password := r.FormValue(passwordFormId)
|
||||
|
||||
if password == "" {
|
||||
// This cannot be authenticated so not worth wasting time sending the request.
|
||||
return "", passwordFormId
|
||||
}
|
||||
|
||||
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: p.id,
|
||||
AccountId: p.accountId,
|
||||
@@ -56,7 +61,3 @@ func (p Password) Authenticate(r *http.Request) (string, string) {
|
||||
|
||||
return "", passwordFormId
|
||||
}
|
||||
|
||||
func (p Password) Middleware(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
|
||||
@@ -36,6 +36,11 @@ func (Pin) Type() Method {
|
||||
func (p Pin) Authenticate(r *http.Request) (string, string) {
|
||||
pin := r.FormValue(pinFormId)
|
||||
|
||||
if pin == "" {
|
||||
// This cannot be authenticated so not worth wasting time sending the request.
|
||||
return "", pinFormId
|
||||
}
|
||||
|
||||
res, err := p.client.Authenticate(r.Context(), &proto.AuthenticateRequest{
|
||||
Id: p.id,
|
||||
AccountId: p.accountId,
|
||||
@@ -56,7 +61,3 @@ func (p Pin) Authenticate(r *http.Request) (string, string) {
|
||||
|
||||
return "", pinFormId
|
||||
}
|
||||
|
||||
func (p Pin) Middleware(next http.Handler) http.Handler {
|
||||
return next
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user