package auth import ( "context" "net/http" "net/url" "strings" "sync" "time" gojwt "github.com/golang-jwt/jwt/v5" "github.com/netbirdio/netbird/shared/management/proto" "google.golang.org/grpc" "github.com/netbirdio/netbird/shared/auth/jwt" ) const stateExpiration = 10 * time.Minute type urlGenerator interface { GetOIDCURL(context.Context, *proto.GetOIDCURLRequest, ...grpc.CallOption) (*proto.GetOIDCURLResponse, error) } // OIDCConfig holds configuration for OIDC JWT verification type OIDCConfig struct { Issuer string Audiences []string KeysLocation string MaxTokenAgeSeconds int64 } // oidcState stores CSRF state with expiration type oidcState struct { OriginalURL string CreatedAt time.Time } // OIDC implements the Scheme interface for JWT/OIDC authentication type OIDC struct { id, accountId string validator *jwt.Validator maxTokenAgeSeconds int64 client urlGenerator states map[string]*oidcState statesMux sync.RWMutex } // NewOIDC creates a new OIDC authentication scheme func NewOIDC(client urlGenerator, id, accountId string, cfg OIDCConfig) *OIDC { o := &OIDC{ id: id, accountId: accountId, validator: jwt.NewValidator( cfg.Issuer, cfg.Audiences, cfg.KeysLocation, true, ), maxTokenAgeSeconds: cfg.MaxTokenAgeSeconds, client: client, states: make(map[string]*oidcState), } go o.cleanupStates() return o } func (*OIDC) Type() Method { return MethodOIDC } func (o *OIDC) Authenticate(r *http.Request) (string, string) { // Try Authorization: Bearer header if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { if userID := o.validateToken(r.Context(), strings.TrimPrefix(auth, "Bearer ")); userID != "" { return userID, "" } } redirectURL := &url.URL{ Scheme: "https", Host: r.Host, Path: r.URL.Path, } res, err := o.client.GetOIDCURL(r.Context(), &proto.GetOIDCURLRequest{ Id: o.id, AccountId: o.accountId, RedirectUrl: redirectURL.String(), }) if err != nil { // TODO: log return "", "" } return "", res.GetUrl() } // validateToken validates a JWT ID token and returns the user ID (subject) // Returns empty string if token is invalid. func (o *OIDC) validateToken(ctx context.Context, token string) string { if o.validator == nil { return "" } idToken, err := o.validator.ValidateAndParse(ctx, token) if err != nil { // TODO: log or return? return "" } iat, err := idToken.Claims.GetIssuedAt() if err != nil { // TODO: log or return? return "" } if time.Since(iat.Time).Seconds() > float64(o.maxTokenAgeSeconds) { // TODO: log or return? return "" } return extractUserID(idToken) } func extractUserID(token *gojwt.Token) string { if token == nil { return "unknown" } claims, ok := token.Claims.(gojwt.MapClaims) if !ok { return "unknown" } return getUserIDFromClaims(claims) } func getUserIDFromClaims(claims gojwt.MapClaims) string { if sub, ok := claims["sub"].(string); ok && sub != "" { return sub } if userID, ok := claims["user_id"].(string); ok && userID != "" { return userID } if email, ok := claims["email"].(string); ok && email != "" { return email } return "unknown" } // cleanupStates periodically removes expired states func (o *OIDC) cleanupStates() { for range time.Tick(time.Minute) { cutoff := time.Now().Add(-stateExpiration) o.statesMux.Lock() for k, v := range o.states { if v.CreatedAt.Before(cutoff) { delete(o.states, k) } } o.statesMux.Unlock() } }