Files
netbird/proxy/internal/auth/oidc.go
2026-02-04 11:51:46 +00:00

164 lines
3.5 KiB
Go

package auth
import (
"context"
"net/http"
"net/url"
"strings"
"sync"
"time"
gojwt "github.com/golang-jwt/jwt/v5"
"github.com/netbirdio/netbird/shared/management/proto"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/shared/auth/jwt"
)
// stateExpiration is the maximum age of an entry in OIDC.states; cleanupStates
// deletes entries whose CreatedAt is older than this.
const stateExpiration = 10 * time.Minute
// urlGenerator is the one RPC the OIDC scheme needs from its client:
// GetOIDCURL. Authenticate calls it when a request carries no acceptable
// bearer token and returns the URL from the response to its caller.
type urlGenerator interface {
GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error)
}
// OIDCConfig holds configuration for OIDC JWT verification.
// Issuer, Audiences and KeysLocation are forwarded, in that order, to
// jwt.NewValidator by NewOIDC.
type OIDCConfig struct {
// Issuer is the validator's issuer argument (presumably the expected "iss"
// value — confirm against jwt.NewValidator).
Issuer string
// Audiences is the validator's audience list argument.
Audiences []string
// KeysLocation is the validator's keys-location argument (presumably a
// JWKS endpoint — confirm against jwt.NewValidator).
KeysLocation string
// MaxTokenAgeSeconds is the maximum number of seconds that may have passed
// since a token's "iat" claim; older tokens are rejected by validateToken.
MaxTokenAgeSeconds int64
}
// oidcState stores CSRF state with expiration.
type oidcState struct {
// OriginalURL is presumably the URL the user first requested, kept so the
// flow can return there — no reader is visible in this file section.
OriginalURL string
// CreatedAt is compared against stateExpiration by cleanupStates to decide
// when the entry is removed.
CreatedAt time.Time
}
// OIDC implements the Scheme interface for JWT/OIDC authentication.
// It contains a mutex, so it must be used through a pointer and not copied.
type OIDC struct {
// id and accountId are sent as Id and AccountId in every GetOIDCURL request.
id, accountId string
// validator verifies and parses bearer tokens; when nil, validateToken
// rejects every token.
validator *jwt.Validator
// maxTokenAgeSeconds is the maximum accepted age of a token, measured in
// seconds since its "iat" claim.
maxTokenAgeSeconds int64
// client issues the GetOIDCURL RPC used by Authenticate.
client urlGenerator
// states holds pending state entries keyed by string; expired entries are
// removed by cleanupStates. NOTE(review): no code inserting into this map
// is visible in this part of the file.
states map[string]*oidcState
// statesMux guards states.
statesMux sync.RWMutex
}
// NewOIDC creates a new OIDC authentication scheme.
//
// client is used to request OIDC URLs, id and accountId are sent with every
// such request, and cfg supplies the JWT validator settings and the maximum
// token age. The returned scheme has an empty state map and a background
// goroutine running cleanupStates.
//
// NOTE(review): that goroutine has no stop mechanism in the visible code, so
// it lives for as long as the process does.
func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC {
	// The meaning of the trailing `true` is defined by jwt.NewValidator.
	validator := jwt.NewValidator(cfg.Issuer, cfg.Audiences, cfg.KeysLocation, true)

	scheme := &OIDC{
		id:                 id,
		accountId:          accountId,
		validator:          validator,
		maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds,
		client:             client,
		states:             make(map[string]*oidcState),
	}

	go scheme.cleanupStates()

	return scheme
}
// Type reports the authentication method of this scheme, which is always
// MethodOIDC.
func (o *OIDC) Type() Method {
	return MethodOIDC
}
// Authenticate checks the request for an "Authorization: Bearer <token>"
// header and validates the token.
//
// It returns (userID, "") when the header carries a token accepted by
// validateToken. Otherwise it asks the client for an OIDC URL, passing
// https://<host><path> of the current request as the redirect URL, and
// returns ("", url). If that RPC fails, both return values are empty.
func (o *OIDC) Authenticate(r *http.Request) (string, string) {
	// HTTP authentication scheme names are case-insensitive, so "bearer" and
	// "BEARER" must be accepted as well as "Bearer". A header with an empty
	// token is skipped; it could never validate.
	const bearerPrefix = "Bearer "
	if auth := r.Header.Get("Authorization"); len(auth) > len(bearerPrefix) &&
		strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
		if userID := o.validateToken(r.Context(), auth[len(bearerPrefix):]); userID != "" {
			return userID, ""
		}
	}

	// NOTE(review): only scheme, host and path are carried over; the query
	// string of the original request is not part of the redirect URL —
	// confirm this is intended.
	redirectURL := &url.URL{
		Scheme: "https",
		Host:   r.Host,
		Path:   r.URL.Path,
	}

	res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{
		Id:          o.id,
		AccountId:   o.accountId,
		RedirectUrl: redirectURL.String(),
	})
	if err != nil {
		// TODO: log. The error is deliberately not propagated: the signature
		// has no error return, so the caller only sees two empty strings.
		return "", ""
	}

	return "", res.GetUrl()
}
// validateToken validates a JWT ID token and returns the user ID derived from
// its claims by extractUserID.
//
// It returns the empty string when the token is not accepted: no validator is
// configured, validation or parsing fails, the "iat" claim is missing or
// cannot be read, or more than maxTokenAgeSeconds have passed since "iat".
func (o *OIDC) validateToken(ctx context.Context, token string) string {
	if o.validator == nil {
		return ""
	}

	idToken, err := o.validator.ValidateAndParse(ctx, token)
	if err != nil {
		// TODO: log or return?
		return ""
	}

	iat, err := idToken.Claims.GetIssuedAt()
	if err != nil || iat == nil {
		// GetIssuedAt reports an absent "iat" claim as (nil, nil). Without an
		// issue time the age limit cannot be enforced, so the token is
		// rejected rather than dereferencing a nil date.
		// TODO: log or return?
		return ""
	}

	if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) {
		// TODO: log or return?
		return ""
	}

	return extractUserID(idToken)
}
// extractUserID returns the user identifier found in the token's claims, as
// chosen by getUserIDFromClaims. It returns "unknown" when the token is nil
// or its claims are not a gojwt.MapClaims.
func extractUserID(token *gojwt.Token) string {
	if token != nil {
		if mapClaims, isMap := token.Claims.(gojwt.MapClaims); isMap {
			return getUserIDFromClaims(mapClaims)
		}
	}
	return "unknown"
}
// getUserIDFromClaims picks a user identifier from the claims. It checks
// "sub", then "user_id", then "email", and returns the first one that holds a
// non-empty string. If none does, it returns "unknown".
func getUserIDFromClaims(claims gojwt.MapClaims) string {
	// Order matters: earlier keys take precedence over later ones.
	for _, key := range [...]string{"sub", "user_id", "email"} {
		if value, isString := claims[key].(string); isString && value != "" {
			return value
		}
	}
	return "unknown"
}
// cleanupStates removes expired entries from o.states once per minute. An
// entry is expired when its CreatedAt is more than stateExpiration in the
// past. The loop has no exit condition, so this method is meant to run in its
// own goroutine and never returns.
func (o *OIDC) cleanupStates() {
	ticker := time.NewTicker(time.Minute)
	// Deferred so the ticker is released if the loop is ever given an exit path.
	defer ticker.Stop()

	for range ticker.C {
		oldest := time.Now().Add(-stateExpiration)

		o.statesMux.Lock()
		for key, entry := range o.states {
			if !entry.CreatedAt.Before(oldest) {
				continue
			}
			delete(o.states, key)
		}
		o.statesMux.Unlock()
	}
}