From 6cfc7b7c699b45d6a3d982b136727a66a9410e57 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 6 Feb 2026 16:27:31 -0800 Subject: [PATCH] Switch to the new tier system and clean up checks --- server/lib/billing/limitSet.ts | 4 +- server/private/lib/billing/getOrgTierData.ts | 47 ++++++++++++++++--- server/private/lib/checkOrgAccessPolicy.ts | 2 - server/private/lib/isLicencedOrSubscribed.ts | 7 ++- server/private/lib/isSubscribed.ts | 24 ++++++++++ .../private/middlewares/verifySubscription.ts | 5 +- .../routers/approvals/listApprovals.ts | 16 +------ .../approvals/processPendingApproval.ts | 17 ------- .../routers/billing/createCheckoutSession.ts | 2 +- .../hooks/handleSubscriptionCreated.ts | 3 +- .../hooks/handleSubscriptionDeleted.ts | 5 +- .../hooks/handleSubscriptionUpdated.ts | 3 +- .../routers/billing/subscriptionLifecycle.ts | 38 ++++++++++++--- server/private/routers/external.ts | 11 ++++- .../routers/loginPage/createLoginPage.ts | 17 +------ .../loginPage/deleteLoginPageBranding.ts | 16 +------ .../routers/loginPage/getLoginPageBranding.ts | 17 +------ .../routers/loginPage/updateLoginPage.ts | 16 +------ .../loginPage/upsertLoginPageBranding.ts | 15 ------ .../routers/orgIdp/createOrgOidcIdp.ts | 16 ------- .../routers/orgIdp/updateOrgOidcIdp.ts | 16 ------- server/routers/badger/verifySession.ts | 12 ++--- server/routers/idp/generateOidcUrl.ts | 6 +-- server/routers/org/updateOrg.ts | 8 ++-- server/routers/user/createOrgUser.ts | 8 ++-- .../settings/(private)/billing/page.tsx | 1 - .../settings/access/users/create/page.tsx | 1 - src/app/auth/resource/[resourceGuid]/page.tsx | 2 - src/contexts/subscriptionStatusContext.ts | 2 +- src/lib/api/isOrgSubscribed.ts | 3 +- src/providers/SubscriptionStatusProvider.tsx | 38 ++++++++------- 31 files changed, 163 insertions(+), 215 deletions(-) create mode 100644 server/private/lib/isSubscribed.ts diff --git a/server/lib/billing/limitSet.ts b/server/lib/billing/limitSet.ts index 12aea306..0419262c 100644 --- a/server/lib/billing/limitSet.ts +++ b/server/lib/billing/limitSet.ts @@ -1,11 +1,11 @@ import { FeatureId } from "./features"; -export type LimitSet = { +export type LimitSet = Partial<{ [key in FeatureId]: { value: number | null; // null indicates no limit description?: string; }; -}; +}>; export const sandboxLimitSet: LimitSet = { [FeatureId.SITES]: { value: 1, description: "Sandbox limit" }, // 1 site up for 2 days diff --git a/server/private/lib/billing/getOrgTierData.ts b/server/private/lib/billing/getOrgTierData.ts index 68e7ea2c..174148f6 100644 --- a/server/private/lib/billing/getOrgTierData.ts +++ b/server/private/lib/billing/getOrgTierData.ts @@ -11,23 +11,58 @@ * This file is not licensed under the AGPLv3. */ -import { getTierPriceSet } from "@server/lib/billing/tiers"; -import { getOrgSubscriptionsData } from "@server/private/routers/billing/getOrgSubscriptions"; import { build } from "@server/build"; +import { db, customers, subscriptions } from "@server/db"; +import { eq, and, ne } from "drizzle-orm"; export async function getOrgTierData( orgId: string -): Promise<{ tier: string | null; active: boolean }> { - let tier = null; +): Promise<{ tier: "home_lab" | "starter" | "scale" | null; active: boolean }> { + let tier: "home_lab" | "starter" | "scale" | null = null; let active = false; if (build !== "saas") { return { tier, active }; } - // TODO: THIS IS INEFFICIENT!!! WE SHOULD IMPROVE HOW WE STORE TIERS WITH SUBSCRIPTIONS AND RETRIEVE THEM + try { + // Get customer for org + const [customer] = await db + .select() + .from(customers) + .where(eq(customers.orgId, orgId)) + .limit(1); - const subscriptionsWithItems = await getOrgSubscriptionsData(orgId); + if (customer) { + // Query for active subscriptions that are not license type + const [subscription] = await db + .select() + .from(subscriptions) + .where( + and( + eq(subscriptions.customerId, customer.customerId), + eq(subscriptions.status, "active"), + ne(subscriptions.type, "license") + ) + ) + .limit(1); + + if (subscription) { + // Validate that subscription.type is one of the expected tier values + if ( + subscription.type === "home_lab" || + subscription.type === "starter" || + subscription.type === "scale" + ) { + tier = subscription.type; + active = true; + } + } + } + } catch (error) { + // If org not found or error occurs, return null tier and inactive + // This is acceptable behavior as per the function signature + } return { tier, active }; } diff --git a/server/private/lib/checkOrgAccessPolicy.ts b/server/private/lib/checkOrgAccessPolicy.ts index cb40c8b8..af318ce0 100644 --- a/server/private/lib/checkOrgAccessPolicy.ts +++ b/server/private/lib/checkOrgAccessPolicy.ts @@ -13,8 +13,6 @@ import { build } from "@server/build"; import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import license from "#private/license/license"; import { eq } from "drizzle-orm"; import { diff --git a/server/private/lib/isLicencedOrSubscribed.ts b/server/private/lib/isLicencedOrSubscribed.ts index 494deb7a..2e8c04fa 100644 --- a/server/private/lib/isLicencedOrSubscribed.ts +++ b/server/private/lib/isLicencedOrSubscribed.ts @@ -14,7 +14,6 @@ import { build } from "@server/build"; import license from "#private/license/license"; import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; export async function isLicensedOrSubscribed(orgId: string): Promise { if (build === "enterprise") { @@ -22,9 +21,9 @@ export async function isLicensedOrSubscribed(orgId: string): Promise { } if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - return tier === TierId.STANDARD; + const { tier, active } = await getOrgTierData(orgId); + return (tier == "home_lab" || tier == "starter" || tier == "scale") && active; } return false; -} \ No newline at end of file +} diff --git a/server/private/lib/isSubscribed.ts b/server/private/lib/isSubscribed.ts new file mode 100644 index 00000000..9ff71bca --- /dev/null +++ b/server/private/lib/isSubscribed.ts @@ -0,0 +1,24 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { build } from "@server/build"; +import { getOrgTierData } from "#private/lib/billing"; + +export async function isSubscribed(orgId: string): Promise { + if (build === "saas") { + const { tier, active } = await getOrgTierData(orgId); + return (tier == "home_lab" || tier == "starter" || tier == "scale") && active; + } + + return false; +} diff --git a/server/private/middlewares/verifySubscription.ts b/server/private/middlewares/verifySubscription.ts index 8cda737e..f1b7a0ce 100644 --- a/server/private/middlewares/verifySubscription.ts +++ b/server/private/middlewares/verifySubscription.ts @@ -38,9 +38,8 @@ export async function verifyValidSubscription( ); } - const tier = await getOrgTierData(orgId); - - if (!tier.active) { + const { tier, active } = await getOrgTierData(orgId); + if ((tier == "home_lab" || tier == "starter" || tier == "scale") && active) { return next( createHttpError( HttpCode.FORBIDDEN, diff --git a/server/private/routers/approvals/listApprovals.ts b/server/private/routers/approvals/listApprovals.ts index 600eec87..509df5eb 100644 --- a/server/private/routers/approvals/listApprovals.ts +++ b/server/private/routers/approvals/listApprovals.ts @@ -19,8 +19,6 @@ import { fromError } from "zod-validation-error"; import type { Request, Response, NextFunction } from "express"; import { build } from "@server/build"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import { approvals, clients, @@ -33,6 +31,7 @@ import { import { eq, isNull, sql, not, and, desc } from "drizzle-orm"; import response from "@server/lib/response"; import { getUserDeviceName } from "@server/db/names"; +import { isLicensedOrSubscribed } from "@server/private/lib/isLicencedOrSubscribed"; const paramsSchema = z.strictObject({ orgId: z.string() @@ -221,19 +220,6 @@ export async function listApprovals( const { orgId } = parsedParams.data; - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - const approvalsList = await queryApprovals( orgId.toString(), limit, diff --git a/server/private/routers/approvals/processPendingApproval.ts b/server/private/routers/approvals/processPendingApproval.ts index d4988ac5..fa60445f 100644 --- a/server/private/routers/approvals/processPendingApproval.ts +++ b/server/private/routers/approvals/processPendingApproval.ts @@ -17,10 +17,7 @@ import createHttpError from "http-errors"; import { z } from "zod"; import { fromError } from "zod-validation-error"; -import { build } from "@server/build"; import { approvals, clients, db, orgs, type Approval } from "@server/db"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import response from "@server/lib/response"; import { and, eq, type InferInsertModel } from "drizzle-orm"; import type { NextFunction, Request, Response } from "express"; @@ -64,20 +61,6 @@ export async function processPendingApproval( } const { orgId, approvalId } = parsedParams.data; - - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - const updateData = parsedBody.data; const approval = await db diff --git a/server/private/routers/billing/createCheckoutSession.ts b/server/private/routers/billing/createCheckoutSession.ts index 7c79b8fb..a90c5e86 100644 --- a/server/private/routers/billing/createCheckoutSession.ts +++ b/server/private/routers/billing/createCheckoutSession.ts @@ -32,7 +32,7 @@ const createCheckoutSessionBodySchema = z.strictObject({ tier: z.enum(["home_lab", "starter", "scale"]), }); -export async function createCheckoutSessionSAAS( +export async function createCheckoutSession( req: Request, res: Response, next: NextFunction diff --git a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts index 6238e65c..4553c986 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts @@ -139,7 +139,8 @@ export async function handleSubscriptionCreated( // we only need to handle the limit lifecycle for saas subscriptions not for the licenses await handleSubscriptionLifesycle( customer.orgId, - subscription.status + subscription.status, + type ); const [orgUserRes] = await db diff --git a/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts b/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts index 003110aa..a35b3be6 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionDeleted.ts @@ -76,14 +76,15 @@ export async function handleSubscriptionDeleted( } const type = getSubType(fullSubscription); - if (type === "saas") { + if (type == "home_lab" || type == "starter" || type == "scale") { logger.debug( `Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}` ); await handleSubscriptionLifesycle( customer.orgId, - subscription.status + subscription.status, + type ); const [orgUserRes] = await db diff --git a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts index 6b21fb26..8f808944 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts @@ -244,7 +244,8 @@ export async function handleSubscriptionUpdated( // we only need to handle the limit lifecycle for saas subscriptions not for the licenses await handleSubscriptionLifesycle( customer.orgId, - subscription.status + subscription.status, + type ); } else if (type === "license") { if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") { diff --git a/server/private/routers/billing/subscriptionLifecycle.ts b/server/private/routers/billing/subscriptionLifecycle.ts index 0fc75835..73a58748 100644 --- a/server/private/routers/billing/subscriptionLifecycle.ts +++ b/server/private/routers/billing/subscriptionLifecycle.ts @@ -13,36 +13,62 @@ import { freeLimitSet, + homeLabLimitSet, + starterLimitSet, + scaleLimitSet, limitsService, - subscribedLimitSet + LimitSet } from "@server/lib/billing"; import { usageService } from "@server/lib/billing/usageService"; -import logger from "@server/logger"; +import { SubscriptionType } from "./hooks/getSubType"; + +function getLimitSetForSubscriptionType(subType: SubscriptionType | null): LimitSet { + switch (subType) { + case "home_lab": + return homeLabLimitSet; + case "starter": + return starterLimitSet; + case "scale": + return scaleLimitSet; + case "license": + // License subscriptions use starter limits by default + // This can be adjusted based on your business logic + return starterLimitSet; + default: + return freeLimitSet; + } +} export async function handleSubscriptionLifesycle( orgId: string, - status: string + status: string, + subType: SubscriptionType | null ) { switch (status) { case "active": - await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet); + const activeLimitSet = getLimitSetForSubscriptionType(subType); + await limitsService.applyLimitSetToOrg(orgId, activeLimitSet); await usageService.checkLimitSet(orgId, true); break; case "canceled": + // Subscription canceled - revert to free tier await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await usageService.checkLimitSet(orgId, true); break; case "past_due": - // Optionally handle past due status, e.g., notify customer + // Payment past due - keep current limits but notify customer + // Limits will revert to free tier if it becomes unpaid break; case "unpaid": + // Subscription unpaid - revert to free tier await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await usageService.checkLimitSet(orgId, true); break; case "incomplete": - // Optionally handle incomplete status, e.g., notify customer + // Payment incomplete - give them time to complete payment break; case "incomplete_expired": + // Payment never completed - revert to free tier await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await usageService.checkLimitSet(orgId, true); break; diff --git a/server/private/routers/external.ts b/server/private/routers/external.ts index 962cede9..0ef9077d 100644 --- a/server/private/routers/external.ts +++ b/server/private/routers/external.ts @@ -76,6 +76,7 @@ unauthenticated.post( authenticated.put( "/org/:orgId/idp/oidc", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.createIdp), logActionAudit(ActionsEnum.createIdp), @@ -85,6 +86,7 @@ authenticated.put( authenticated.post( "/org/:orgId/idp/:idpId/oidc", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyIdpAccess, verifyUserHasAction(ActionsEnum.updateIdp), @@ -146,7 +148,7 @@ if (build === "saas") { verifyOrgAccess, verifyUserHasAction(ActionsEnum.billing), logActionAudit(ActionsEnum.billing), - billing.createCheckoutSessionSAAS + billing.createCheckoutSession ); authenticated.post( @@ -269,6 +271,7 @@ authenticated.delete( authenticated.put( "/org/:orgId/login-page", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.createLoginPage), logActionAudit(ActionsEnum.createLoginPage), @@ -278,6 +281,7 @@ authenticated.put( authenticated.post( "/org/:orgId/login-page/:loginPageId", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyLoginPageAccess, verifyUserHasAction(ActionsEnum.updateLoginPage), @@ -306,6 +310,7 @@ authenticated.get( authenticated.get( "/org/:orgId/approvals", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.listApprovals), logActionAudit(ActionsEnum.listApprovals), @@ -322,6 +327,7 @@ authenticated.get( authenticated.put( "/org/:orgId/approvals/:approvalId", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.updateApprovals), logActionAudit(ActionsEnum.updateApprovals), @@ -331,6 +337,7 @@ authenticated.put( authenticated.get( "/org/:orgId/login-page-branding", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.getLoginPage), logActionAudit(ActionsEnum.getLoginPage), @@ -340,6 +347,7 @@ authenticated.get( authenticated.put( "/org/:orgId/login-page-branding", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.updateLoginPage), logActionAudit(ActionsEnum.updateLoginPage), @@ -349,6 +357,7 @@ authenticated.put( authenticated.delete( "/org/:orgId/login-page-branding", verifyValidLicense, + verifyValidSubscription, verifyOrgAccess, verifyUserHasAction(ActionsEnum.deleteLoginPage), logActionAudit(ActionsEnum.deleteLoginPage), diff --git a/server/private/routers/loginPage/createLoginPage.ts b/server/private/routers/loginPage/createLoginPage.ts index b5e8ccff..72b8a28f 100644 --- a/server/private/routers/loginPage/createLoginPage.ts +++ b/server/private/routers/loginPage/createLoginPage.ts @@ -30,9 +30,7 @@ import { fromError } from "zod-validation-error"; import { eq, and } from "drizzle-orm"; import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { createCertificate } from "#private/routers/certificates/createCertificate"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; -import { build } from "@server/build"; + import { CreateLoginPageResponse } from "@server/routers/loginPage/types"; const paramsSchema = z.strictObject({ @@ -76,19 +74,6 @@ export async function createLoginPage( const { orgId } = parsedParams.data; - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - const [existing] = await db .select() .from(loginPageOrg) diff --git a/server/private/routers/loginPage/deleteLoginPageBranding.ts b/server/private/routers/loginPage/deleteLoginPageBranding.ts index 1fb243b0..0a59ce4e 100644 --- a/server/private/routers/loginPage/deleteLoginPageBranding.ts +++ b/server/private/routers/loginPage/deleteLoginPageBranding.ts @@ -25,9 +25,7 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; -import { build } from "@server/build"; + const paramsSchema = z .object({ @@ -53,18 +51,6 @@ export async function deleteLoginPageBranding( const { orgId } = parsedParams.data; - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } const [existingLoginPageBranding] = await db .select() diff --git a/server/private/routers/loginPage/getLoginPageBranding.ts b/server/private/routers/loginPage/getLoginPageBranding.ts index 8fd0772d..ce133c7c 100644 --- a/server/private/routers/loginPage/getLoginPageBranding.ts +++ b/server/private/routers/loginPage/getLoginPageBranding.ts @@ -25,9 +25,7 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; -import { build } from "@server/build"; + const paramsSchema = z.strictObject({ orgId: z.string() @@ -51,19 +49,6 @@ export async function getLoginPageBranding( const { orgId } = parsedParams.data; - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - const [existingLoginPageBranding] = await db .select() .from(loginPageBranding) diff --git a/server/private/routers/loginPage/updateLoginPage.ts b/server/private/routers/loginPage/updateLoginPage.ts index bda614d3..6226dda2 100644 --- a/server/private/routers/loginPage/updateLoginPage.ts +++ b/server/private/routers/loginPage/updateLoginPage.ts @@ -23,9 +23,7 @@ import { eq, and } from "drizzle-orm"; import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { subdomainSchema } from "@server/lib/schemas"; import { createCertificate } from "#private/routers/certificates/createCertificate"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; -import { build } from "@server/build"; + import { UpdateLoginPageResponse } from "@server/routers/loginPage/types"; const paramsSchema = z @@ -87,18 +85,6 @@ export async function updateLoginPage( const { loginPageId, orgId } = parsedParams.data; - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } const [existingLoginPage] = await db .select() diff --git a/server/private/routers/loginPage/upsertLoginPageBranding.ts b/server/private/routers/loginPage/upsertLoginPageBranding.ts index e6e365be..e81628dc 100644 --- a/server/private/routers/loginPage/upsertLoginPageBranding.ts +++ b/server/private/routers/loginPage/upsertLoginPageBranding.ts @@ -25,8 +25,6 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, InferInsertModel } from "drizzle-orm"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import { build } from "@server/build"; import config from "@server/private/lib/config"; @@ -128,19 +126,6 @@ export async function upsertLoginPageBranding( const { orgId } = parsedParams.data; - if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - let updateData = parsedBody.data satisfies InferInsertModel< typeof loginPageBranding >; diff --git a/server/private/routers/orgIdp/createOrgOidcIdp.ts b/server/private/routers/orgIdp/createOrgOidcIdp.ts index 998a159f..bee04340 100644 --- a/server/private/routers/orgIdp/createOrgOidcIdp.ts +++ b/server/private/routers/orgIdp/createOrgOidcIdp.ts @@ -24,9 +24,6 @@ import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db"; import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl"; import { encrypt } from "@server/lib/crypto"; import config from "@server/lib/config"; -import { build } from "@server/build"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types"; const paramsSchema = z.strictObject({ orgId: z.string().nonempty() }); @@ -109,19 +106,6 @@ export async function createOrgOidcIdp( tags } = parsedBody.data; - if (build === "saas") { - const { tier, active } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - const key = config.getRawConfig().server.secret!; const encryptedSecret = encrypt(clientSecret, key); diff --git a/server/private/routers/orgIdp/updateOrgOidcIdp.ts b/server/private/routers/orgIdp/updateOrgOidcIdp.ts index d8ef415c..e01bdba0 100644 --- a/server/private/routers/orgIdp/updateOrgOidcIdp.ts +++ b/server/private/routers/orgIdp/updateOrgOidcIdp.ts @@ -24,9 +24,6 @@ import { idp, idpOidcConfig } from "@server/db"; import { eq, and } from "drizzle-orm"; import { encrypt } from "@server/lib/crypto"; import config from "@server/lib/config"; -import { build } from "@server/build"; -import { getOrgTierData } from "#private/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; const paramsSchema = z .object({ @@ -114,19 +111,6 @@ export async function updateOrgOidcIdp( tags } = parsedBody.data; - if (build === "saas") { - const { tier, active } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { - return next( - createHttpError( - HttpCode.FORBIDDEN, - "This organization's current plan does not support this feature." - ) - ); - } - } - // Check if IDP exists and is of type OIDC const [existingIdp] = await db .select() diff --git a/server/routers/badger/verifySession.ts b/server/routers/badger/verifySession.ts index 3226755d..fa81b6f9 100644 --- a/server/routers/badger/verifySession.ts +++ b/server/routers/badger/verifySession.ts @@ -18,7 +18,6 @@ import { ResourcePassword, ResourcePincode, ResourceRule, - resourceSessions } from "@server/db"; import config from "@server/lib/config"; import { isIpInCidr, stripPortFromHost } from "@server/lib/ip"; @@ -32,7 +31,6 @@ import { fromError } from "zod-validation-error"; import { getCountryCodeForIp } from "@server/lib/geoip"; import { getAsnForIp } from "@server/lib/asn"; import { getOrgTierData } from "#dynamic/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import { verifyPassword } from "@server/auth/password"; import { checkOrgAccessPolicy, @@ -40,8 +38,8 @@ import { } from "#dynamic/lib/checkOrgAccessPolicy"; import { logRequestAudit } from "./logRequestAudit"; import cache from "@server/lib/cache"; -import semver from "semver"; import { APP_VERSION } from "@server/lib/consts"; +import { isSubscribed } from "#private/lib/isSubscribed"; const verifyResourceSessionSchema = z.object({ sessions: z.record(z.string(), z.string()).optional(), @@ -798,8 +796,8 @@ async function notAllowed( ) { let loginPage: LoginPage | null = null; if (orgId) { - const { tier } = await getOrgTierData(orgId); // returns null in oss - if (tier === TierId.STANDARD) { + const subscribed = await isSubscribed(orgId); + if (subscribed) { loginPage = await getOrgLoginPage(orgId); } } @@ -852,8 +850,8 @@ async function headerAuthChallenged( ) { let loginPage: LoginPage | null = null; if (orgId) { - const { tier } = await getOrgTierData(orgId); // returns null in oss - if (tier === TierId.STANDARD) { + const subscribed = await isSubscribed(orgId); + if (subscribed) { loginPage = await getOrgLoginPage(orgId); } } diff --git a/server/routers/idp/generateOidcUrl.ts b/server/routers/idp/generateOidcUrl.ts index 50b63ee5..5743631b 100644 --- a/server/routers/idp/generateOidcUrl.ts +++ b/server/routers/idp/generateOidcUrl.ts @@ -14,8 +14,7 @@ import jsonwebtoken from "jsonwebtoken"; import config from "@server/lib/config"; import { decrypt } from "@server/lib/crypto"; import { build } from "@server/build"; -import { getOrgTierData } from "#dynamic/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; +import { isSubscribed } from "@server/private/lib/isSubscribed"; const paramsSchema = z .object({ @@ -113,8 +112,7 @@ export async function generateOidcUrl( } if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; + const subscribed = await isSubscribed(orgId); if (!subscribed) { return next( createHttpError( diff --git a/server/routers/org/updateOrg.ts b/server/routers/org/updateOrg.ts index 44ff9190..38ffab18 100644 --- a/server/routers/org/updateOrg.ts +++ b/server/routers/org/updateOrg.ts @@ -10,10 +10,10 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; import { build } from "@server/build"; -import { getOrgTierData } from "#dynamic/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import { cache } from "@server/lib/cache"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; +import { subscribe } from "node:diagnostics_channel"; +import { isSubscribed } from "@server/private/lib/isSubscribed"; const updateOrgParamsSchema = z.strictObject({ orgId: z.string() @@ -95,10 +95,10 @@ export async function updateOrg( parsedBody.data.passwordExpiryDays = undefined; } - const { tier } = await getOrgTierData(orgId); + const subscribed = await isSubscribed(orgId); if ( build == "saas" && - tier != TierId.STANDARD && + subscribed && parsedBody.data.settingsLogRetentionDaysRequest && parsedBody.data.settingsLogRetentionDaysRequest > 30 ) { diff --git a/server/routers/user/createOrgUser.ts b/server/routers/user/createOrgUser.ts index b9a1abc9..3fe72a35 100644 --- a/server/routers/user/createOrgUser.ts +++ b/server/routers/user/createOrgUser.ts @@ -13,9 +13,8 @@ import { generateId } from "@server/auth/sessions/app"; import { usageService } from "@server/lib/billing/usageService"; import { FeatureId } from "@server/lib/billing"; import { build } from "@server/build"; -import { getOrgTierData } from "#dynamic/lib/billing"; -import { TierId } from "@server/lib/billing/tiers"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; +import { isSubscribed } from "@server/private/lib/isSubscribed"; const paramsSchema = z.strictObject({ orgId: z.string().nonempty() @@ -132,9 +131,8 @@ export async function createOrgUser( ); } else if (type === "oidc") { if (build === "saas") { - const { tier } = await getOrgTierData(orgId); - const subscribed = tier === TierId.STANDARD; - if (!subscribed) { + const subscribed = await isSubscribed(orgId); + if (subscribed) { return next( createHttpError( HttpCode.FORBIDDEN, diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index 75611f35..259d8a66 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -76,7 +76,6 @@ export default function GeneralPage() { setAllSubscriptions(subscriptions); // Import tier and license price sets - const { getTierPriceSet } = await import("@server/lib/billing/tiers"); const { getLicensePriceSet } = await import("@server/lib/billing/licenses"); const tierPriceSet = getTierPriceSet( diff --git a/src/app/[orgId]/settings/access/users/create/page.tsx b/src/app/[orgId]/settings/access/users/create/page.tsx index 0e55ffeb..7d8ad2cd 100644 --- a/src/app/[orgId]/settings/access/users/create/page.tsx +++ b/src/app/[orgId]/settings/access/users/create/page.tsx @@ -48,7 +48,6 @@ import { useTranslations } from "next-intl"; import { build } from "@server/build"; import Image from "next/image"; import { useSubscriptionStatusContext } from "@app/hooks/useSubscriptionStatusContext"; -import { TierId } from "@server/lib/billing/tiers"; type UserType = "internal" | "oidc"; diff --git a/src/app/auth/resource/[resourceGuid]/page.tsx b/src/app/auth/resource/[resourceGuid]/page.tsx index e2e0330c..5f9b1183 100644 --- a/src/app/auth/resource/[resourceGuid]/page.tsx +++ b/src/app/auth/resource/[resourceGuid]/page.tsx @@ -23,8 +23,6 @@ import type { LoadLoginPageBrandingResponse, LoadLoginPageResponse } from "@server/routers/loginPage/types"; -import { GetOrgTierResponse } from "@server/routers/billing/types"; -import { TierId } from "@server/lib/billing/tiers"; import { CheckOrgUserAccessResponse } from "@server/routers/org"; import OrgPolicyRequired from "@app/components/OrgPolicyRequired"; import { isOrgSubscribed } from "@app/lib/api/isOrgSubscribed"; diff --git a/src/contexts/subscriptionStatusContext.ts b/src/contexts/subscriptionStatusContext.ts index 5b333db4..71fe7004 100644 --- a/src/contexts/subscriptionStatusContext.ts +++ b/src/contexts/subscriptionStatusContext.ts @@ -5,7 +5,7 @@ type SubscriptionStatusContextType = { subscriptionStatus: GetOrgSubscriptionResponse | null; updateSubscriptionStatus: (updatedSite: GetOrgSubscriptionResponse) => void; isActive: () => boolean; - getTier: () => string | null; + getTier: () => { tier: string | null; active: boolean }; isSubscribed: () => boolean; subscribed: boolean; }; diff --git a/src/lib/api/isOrgSubscribed.ts b/src/lib/api/isOrgSubscribed.ts index 9440330b..8eb4b8e8 100644 --- a/src/lib/api/isOrgSubscribed.ts +++ b/src/lib/api/isOrgSubscribed.ts @@ -1,5 +1,4 @@ import { build } from "@server/build"; -import { TierId } from "@server/lib/billing/tiers"; import { cache } from "react"; import { getCachedSubscription } from "./getCachedSubscription"; import { priv } from "."; @@ -21,7 +20,7 @@ export const isOrgSubscribed = cache(async (orgId: string) => { try { const subRes = await getCachedSubscription(orgId); subscribed = - subRes.data.data.tier === TierId.STANDARD && + (subRes.data.data.tier == "home_lab" || subRes.data.data.tier == "starter" || subRes.data.data.tier == "scale") && subRes.data.data.active; } catch {} } diff --git a/src/providers/SubscriptionStatusProvider.tsx b/src/providers/SubscriptionStatusProvider.tsx index eecafce8..f9d8ef8b 100644 --- a/src/providers/SubscriptionStatusProvider.tsx +++ b/src/providers/SubscriptionStatusProvider.tsx @@ -1,7 +1,6 @@ "use client"; import SubscriptionStatusContext from "@app/contexts/subscriptionStatusContext"; -import { getTierPriceSet, TierId } from "@server/lib/billing/tiers"; import { GetOrgSubscriptionResponse } from "@server/routers/billing/types"; import { useState } from "react"; import { build } from "@server/build"; @@ -43,34 +42,37 @@ export function SubscriptionStatusProvider({ }; const getTier = () => { - const tierPriceSet = getTierPriceSet(env, sandbox_mode); - if (subscriptionStatus?.subscriptions) { // Iterate through all subscriptions - for (const { subscription, items } of subscriptionStatus.subscriptions) { - if (items && items.length > 0) { - // Iterate through tiers in order (earlier keys are higher tiers) - for (const [tierId, priceId] of Object.entries(tierPriceSet)) { - // Check if any subscription item matches this tier's price ID - const matchingItem = items.find( - (item) => item.priceId === priceId - ); - if (matchingItem) { - return tierId; - } - } + for (const { subscription } of subscriptionStatus.subscriptions) { + if ( + subscription.type == "home_lab" || + subscription.type == "starter" || + subscription.type == "scale" + ) { + return { + tier: subscription.type, + active: subscription.status === "active" + }; } } } - return null; + return { + tier: null, + active: false + }; }; const isSubscribed = () => { if (build === "enterprise") { return true; } - return getTier() === TierId.STANDARD; + const { tier, active } = getTier(); + return ( + (tier == "home_lab" || tier == "starter" || tier == "scale") && + active + ); }; const [subscribed, setSubscribed] = useState(isSubscribed()); @@ -91,4 +93,4 @@ export function SubscriptionStatusProvider({ ); } -export default SubscriptionStatusProvider; \ No newline at end of file +export default SubscriptionStatusProvider;