Switch to the new tier system and clean up checks

This commit is contained in:
Owen
2026-02-06 16:27:31 -08:00
parent 34cced872f
commit 6cfc7b7c69
31 changed files with 163 additions and 215 deletions

View File

@@ -1,11 +1,11 @@
import { FeatureId } from "./features"; import { FeatureId } from "./features";
export type LimitSet = { export type LimitSet = Partial<{
[key in FeatureId]: { [key in FeatureId]: {
value: number | null; // null indicates no limit value: number | null; // null indicates no limit
description?: string; description?: string;
}; };
}; }>;
export const sandboxLimitSet: LimitSet = { export const sandboxLimitSet: LimitSet = {
[FeatureId.SITES]: { value: 1, description: "Sandbox limit" }, // 1 site up for 2 days [FeatureId.SITES]: { value: 1, description: "Sandbox limit" }, // 1 site up for 2 days

View File

@@ -11,23 +11,58 @@
* This file is not licensed under the AGPLv3. * This file is not licensed under the AGPLv3.
*/ */
import { getTierPriceSet } from "@server/lib/billing/tiers";
import { getOrgSubscriptionsData } from "@server/private/routers/billing/getOrgSubscriptions";
import { build } from "@server/build"; import { build } from "@server/build";
import { db, customers, subscriptions } from "@server/db";
import { eq, and, ne } from "drizzle-orm";
export async function getOrgTierData( export async function getOrgTierData(
orgId: string orgId: string
): Promise<{ tier: string | null; active: boolean }> { ): Promise<{ tier: "home_lab" | "starter" | "scale" | null; active: boolean }> {
let tier = null; let tier: "home_lab" | "starter" | "scale" | null = null;
let active = false; let active = false;
if (build !== "saas") { if (build !== "saas") {
return { tier, active }; return { tier, active };
} }
// TODO: THIS IS INEFFICIENT!!! WE SHOULD IMPROVE HOW WE STORE TIERS WITH SUBSCRIPTIONS AND RETRIEVE THEM try {
// Get customer for org
const [customer] = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.limit(1);
const subscriptionsWithItems = await getOrgSubscriptionsData(orgId); if (customer) {
// Query for active subscriptions that are not license type
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
ne(subscriptions.type, "license")
)
)
.limit(1);
if (subscription) {
// Validate that subscription.type is one of the expected tier values
if (
subscription.type === "home_lab" ||
subscription.type === "starter" ||
subscription.type === "scale"
) {
tier = subscription.type;
active = true;
}
}
}
} catch (error) {
// If org not found or error occurs, return null tier and inactive
// This is acceptable behavior as per the function signature
}
return { tier, active }; return { tier, active };
} }

View File

@@ -13,8 +13,6 @@
import { build } from "@server/build"; import { build } from "@server/build";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db"; import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license"; import license from "#private/license/license";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { import {

View File

@@ -14,7 +14,6 @@
import { build } from "@server/build"; import { build } from "@server/build";
import license from "#private/license/license"; import license from "#private/license/license";
import { getOrgTierData } from "#private/lib/billing"; import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> { export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
if (build === "enterprise") { if (build === "enterprise") {
@@ -22,9 +21,9 @@ export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
} }
if (build === "saas") { if (build === "saas") {
const { tier } = await getOrgTierData(orgId); const { tier, active } = await getOrgTierData(orgId);
return tier === TierId.STANDARD; return (tier == "home_lab" || tier == "starter" || tier == "scale") && active;
} }
return false; return false;
} }

View File

@@ -0,0 +1,24 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
export async function isSubscribed(orgId: string): Promise<boolean> {
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
return (tier == "home_lab" || tier == "starter" || tier == "scale") && active;
}
return false;
}

View File

@@ -38,9 +38,8 @@ export async function verifyValidSubscription(
); );
} }
const tier = await getOrgTierData(orgId); const { tier, active } = await getOrgTierData(orgId);
if ((tier == "home_lab" || tier == "starter" || tier == "scale") && active) {
if (!tier.active) {
return next( return next(
createHttpError( createHttpError(
HttpCode.FORBIDDEN, HttpCode.FORBIDDEN,

View File

@@ -19,8 +19,6 @@ import { fromError } from "zod-validation-error";
import type { Request, Response, NextFunction } from "express"; import type { Request, Response, NextFunction } from "express";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { import {
approvals, approvals,
clients, clients,
@@ -33,6 +31,7 @@ import {
import { eq, isNull, sql, not, and, desc } from "drizzle-orm"; import { eq, isNull, sql, not, and, desc } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
import { getUserDeviceName } from "@server/db/names"; import { getUserDeviceName } from "@server/db/names";
import { isLicensedOrSubscribed } from "@server/private/lib/isLicencedOrSubscribed";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -221,19 +220,6 @@ export async function listApprovals(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const approvalsList = await queryApprovals( const approvalsList = await queryApprovals(
orgId.toString(), orgId.toString(),
limit, limit,

View File

@@ -17,10 +17,7 @@ import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { build } from "@server/build";
import { approvals, clients, db, orgs, type Approval } from "@server/db"; import { approvals, clients, db, orgs, type Approval } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import response from "@server/lib/response"; import response from "@server/lib/response";
import { and, eq, type InferInsertModel } from "drizzle-orm"; import { and, eq, type InferInsertModel } from "drizzle-orm";
import type { NextFunction, Request, Response } from "express"; import type { NextFunction, Request, Response } from "express";
@@ -64,20 +61,6 @@ export async function processPendingApproval(
} }
const { orgId, approvalId } = parsedParams.data; const { orgId, approvalId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const updateData = parsedBody.data; const updateData = parsedBody.data;
const approval = await db const approval = await db

View File

@@ -32,7 +32,7 @@ const createCheckoutSessionBodySchema = z.strictObject({
tier: z.enum(["home_lab", "starter", "scale"]), tier: z.enum(["home_lab", "starter", "scale"]),
}); });
export async function createCheckoutSessionSAAS( export async function createCheckoutSession(
req: Request, req: Request,
res: Response, res: Response,
next: NextFunction next: NextFunction

View File

@@ -139,7 +139,8 @@ export async function handleSubscriptionCreated(
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses // we only need to handle the limit lifecycle for saas subscriptions not for the licenses
await handleSubscriptionLifesycle( await handleSubscriptionLifesycle(
customer.orgId, customer.orgId,
subscription.status subscription.status,
type
); );
const [orgUserRes] = await db const [orgUserRes] = await db

View File

@@ -76,14 +76,15 @@ export async function handleSubscriptionDeleted(
} }
const type = getSubType(fullSubscription); const type = getSubType(fullSubscription);
if (type === "saas") { if (type == "home_lab" || type == "starter" || type == "scale") {
logger.debug( logger.debug(
`Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}` `Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}`
); );
await handleSubscriptionLifesycle( await handleSubscriptionLifesycle(
customer.orgId, customer.orgId,
subscription.status subscription.status,
type
); );
const [orgUserRes] = await db const [orgUserRes] = await db

View File

@@ -244,7 +244,8 @@ export async function handleSubscriptionUpdated(
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses // we only need to handle the limit lifecycle for saas subscriptions not for the licenses
await handleSubscriptionLifesycle( await handleSubscriptionLifesycle(
customer.orgId, customer.orgId,
subscription.status subscription.status,
type
); );
} else if (type === "license") { } else if (type === "license") {
if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") { if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") {

View File

@@ -13,36 +13,62 @@
import { import {
freeLimitSet, freeLimitSet,
homeLabLimitSet,
starterLimitSet,
scaleLimitSet,
limitsService, limitsService,
subscribedLimitSet LimitSet
} from "@server/lib/billing"; } from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import logger from "@server/logger"; import { SubscriptionType } from "./hooks/getSubType";
function getLimitSetForSubscriptionType(subType: SubscriptionType | null): LimitSet {
switch (subType) {
case "home_lab":
return homeLabLimitSet;
case "starter":
return starterLimitSet;
case "scale":
return scaleLimitSet;
case "license":
// License subscriptions use starter limits by default
// This can be adjusted based on your business logic
return starterLimitSet;
default:
return freeLimitSet;
}
}
export async function handleSubscriptionLifesycle( export async function handleSubscriptionLifesycle(
orgId: string, orgId: string,
status: string status: string,
subType: SubscriptionType | null
) { ) {
switch (status) { switch (status) {
case "active": case "active":
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet); const activeLimitSet = getLimitSetForSubscriptionType(subType);
await limitsService.applyLimitSetToOrg(orgId, activeLimitSet);
await usageService.checkLimitSet(orgId, true); await usageService.checkLimitSet(orgId, true);
break; break;
case "canceled": case "canceled":
// Subscription canceled - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId, true); await usageService.checkLimitSet(orgId, true);
break; break;
case "past_due": case "past_due":
// Optionally handle past due status, e.g., notify customer // Payment past due - keep current limits but notify customer
// Limits will revert to free tier if it becomes unpaid
break; break;
case "unpaid": case "unpaid":
// Subscription unpaid - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId, true); await usageService.checkLimitSet(orgId, true);
break; break;
case "incomplete": case "incomplete":
// Optionally handle incomplete status, e.g., notify customer // Payment incomplete - give them time to complete payment
break; break;
case "incomplete_expired": case "incomplete_expired":
// Payment never completed - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId, true); await usageService.checkLimitSet(orgId, true);
break; break;

View File

@@ -76,6 +76,7 @@ unauthenticated.post(
authenticated.put( authenticated.put(
"/org/:orgId/idp/oidc", "/org/:orgId/idp/oidc",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createIdp), verifyUserHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp), logActionAudit(ActionsEnum.createIdp),
@@ -85,6 +86,7 @@ authenticated.put(
authenticated.post( authenticated.post(
"/org/:orgId/idp/:idpId/oidc", "/org/:orgId/idp/:idpId/oidc",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyIdpAccess, verifyIdpAccess,
verifyUserHasAction(ActionsEnum.updateIdp), verifyUserHasAction(ActionsEnum.updateIdp),
@@ -146,7 +148,7 @@ if (build === "saas") {
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.billing), verifyUserHasAction(ActionsEnum.billing),
logActionAudit(ActionsEnum.billing), logActionAudit(ActionsEnum.billing),
billing.createCheckoutSessionSAAS billing.createCheckoutSession
); );
authenticated.post( authenticated.post(
@@ -269,6 +271,7 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/login-page", "/org/:orgId/login-page",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.createLoginPage), verifyUserHasAction(ActionsEnum.createLoginPage),
logActionAudit(ActionsEnum.createLoginPage), logActionAudit(ActionsEnum.createLoginPage),
@@ -278,6 +281,7 @@ authenticated.put(
authenticated.post( authenticated.post(
"/org/:orgId/login-page/:loginPageId", "/org/:orgId/login-page/:loginPageId",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyLoginPageAccess, verifyLoginPageAccess,
verifyUserHasAction(ActionsEnum.updateLoginPage), verifyUserHasAction(ActionsEnum.updateLoginPage),
@@ -306,6 +310,7 @@ authenticated.get(
authenticated.get( authenticated.get(
"/org/:orgId/approvals", "/org/:orgId/approvals",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listApprovals), verifyUserHasAction(ActionsEnum.listApprovals),
logActionAudit(ActionsEnum.listApprovals), logActionAudit(ActionsEnum.listApprovals),
@@ -322,6 +327,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/approvals/:approvalId", "/org/:orgId/approvals/:approvalId",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.updateApprovals), verifyUserHasAction(ActionsEnum.updateApprovals),
logActionAudit(ActionsEnum.updateApprovals), logActionAudit(ActionsEnum.updateApprovals),
@@ -331,6 +337,7 @@ authenticated.put(
authenticated.get( authenticated.get(
"/org/:orgId/login-page-branding", "/org/:orgId/login-page-branding",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.getLoginPage), verifyUserHasAction(ActionsEnum.getLoginPage),
logActionAudit(ActionsEnum.getLoginPage), logActionAudit(ActionsEnum.getLoginPage),
@@ -340,6 +347,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/login-page-branding", "/org/:orgId/login-page-branding",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.updateLoginPage), verifyUserHasAction(ActionsEnum.updateLoginPage),
logActionAudit(ActionsEnum.updateLoginPage), logActionAudit(ActionsEnum.updateLoginPage),
@@ -349,6 +357,7 @@ authenticated.put(
authenticated.delete( authenticated.delete(
"/org/:orgId/login-page-branding", "/org/:orgId/login-page-branding",
verifyValidLicense, verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess, verifyOrgAccess,
verifyUserHasAction(ActionsEnum.deleteLoginPage), verifyUserHasAction(ActionsEnum.deleteLoginPage),
logActionAudit(ActionsEnum.deleteLoginPage), logActionAudit(ActionsEnum.deleteLoginPage),

View File

@@ -30,9 +30,7 @@ import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { createCertificate } from "#private/routers/certificates/createCertificate"; import { createCertificate } from "#private/routers/certificates/createCertificate";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import { CreateLoginPageResponse } from "@server/routers/loginPage/types"; import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
@@ -76,19 +74,6 @@ export async function createLoginPage(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existing] = await db const [existing] = await db
.select() .select()
.from(loginPageOrg) .from(loginPageOrg)

View File

@@ -25,9 +25,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@@ -53,18 +51,6 @@ export async function deleteLoginPageBranding(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPageBranding] = await db const [existingLoginPageBranding] = await db
.select() .select()

View File

@@ -25,9 +25,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -51,19 +49,6 @@ export async function getLoginPageBranding(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPageBranding] = await db const [existingLoginPageBranding] = await db
.select() .select()
.from(loginPageBranding) .from(loginPageBranding)

View File

@@ -23,9 +23,7 @@ import { eq, and } from "drizzle-orm";
import { validateAndConstructDomain } from "@server/lib/domainUtils"; import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { subdomainSchema } from "@server/lib/schemas"; import { subdomainSchema } from "@server/lib/schemas";
import { createCertificate } from "#private/routers/certificates/createCertificate"; import { createCertificate } from "#private/routers/certificates/createCertificate";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import { UpdateLoginPageResponse } from "@server/routers/loginPage/types"; import { UpdateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z const paramsSchema = z
@@ -87,18 +85,6 @@ export async function updateLoginPage(
const { loginPageId, orgId } = parsedParams.data; const { loginPageId, orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPage] = await db const [existingLoginPage] = await db
.select() .select()

View File

@@ -25,8 +25,6 @@ import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq, InferInsertModel } from "drizzle-orm"; import { eq, InferInsertModel } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build"; import { build } from "@server/build";
import config from "@server/private/lib/config"; import config from "@server/private/lib/config";
@@ -128,19 +126,6 @@ export async function upsertLoginPageBranding(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
let updateData = parsedBody.data satisfies InferInsertModel< let updateData = parsedBody.data satisfies InferInsertModel<
typeof loginPageBranding typeof loginPageBranding
>; >;

View File

@@ -24,9 +24,6 @@ import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl"; import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import { encrypt } from "@server/lib/crypto"; import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types"; import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() }); const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
@@ -109,19 +106,6 @@ export async function createOrgOidcIdp(
tags tags
} = parsedBody.data; } = parsedBody.data;
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const key = config.getRawConfig().server.secret!; const key = config.getRawConfig().server.secret!;
const encryptedSecret = encrypt(clientSecret, key); const encryptedSecret = encrypt(clientSecret, key);

View File

@@ -24,9 +24,6 @@ import { idp, idpOidcConfig } from "@server/db";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { encrypt } from "@server/lib/crypto"; import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@@ -114,19 +111,6 @@ export async function updateOrgOidcIdp(
tags tags
} = parsedBody.data; } = parsedBody.data;
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
// Check if IDP exists and is of type OIDC // Check if IDP exists and is of type OIDC
const [existingIdp] = await db const [existingIdp] = await db
.select() .select()

View File

@@ -18,7 +18,6 @@ import {
ResourcePassword, ResourcePassword,
ResourcePincode, ResourcePincode,
ResourceRule, ResourceRule,
resourceSessions
} from "@server/db"; } from "@server/db";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { isIpInCidr, stripPortFromHost } from "@server/lib/ip"; import { isIpInCidr, stripPortFromHost } from "@server/lib/ip";
@@ -32,7 +31,6 @@ import { fromError } from "zod-validation-error";
import { getCountryCodeForIp } from "@server/lib/geoip"; import { getCountryCodeForIp } from "@server/lib/geoip";
import { getAsnForIp } from "@server/lib/asn"; import { getAsnForIp } from "@server/lib/asn";
import { getOrgTierData } from "#dynamic/lib/billing"; import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { verifyPassword } from "@server/auth/password"; import { verifyPassword } from "@server/auth/password";
import { import {
checkOrgAccessPolicy, checkOrgAccessPolicy,
@@ -40,8 +38,8 @@ import {
} from "#dynamic/lib/checkOrgAccessPolicy"; } from "#dynamic/lib/checkOrgAccessPolicy";
import { logRequestAudit } from "./logRequestAudit"; import { logRequestAudit } from "./logRequestAudit";
import cache from "@server/lib/cache"; import cache from "@server/lib/cache";
import semver from "semver";
import { APP_VERSION } from "@server/lib/consts"; import { APP_VERSION } from "@server/lib/consts";
import { isSubscribed } from "#private/lib/isSubscribed";
const verifyResourceSessionSchema = z.object({ const verifyResourceSessionSchema = z.object({
sessions: z.record(z.string(), z.string()).optional(), sessions: z.record(z.string(), z.string()).optional(),
@@ -798,8 +796,8 @@ async function notAllowed(
) { ) {
let loginPage: LoginPage | null = null; let loginPage: LoginPage | null = null;
if (orgId) { if (orgId) {
const { tier } = await getOrgTierData(orgId); // returns null in oss const subscribed = await isSubscribed(orgId);
if (tier === TierId.STANDARD) { if (subscribed) {
loginPage = await getOrgLoginPage(orgId); loginPage = await getOrgLoginPage(orgId);
} }
} }
@@ -852,8 +850,8 @@ async function headerAuthChallenged(
) { ) {
let loginPage: LoginPage | null = null; let loginPage: LoginPage | null = null;
if (orgId) { if (orgId) {
const { tier } = await getOrgTierData(orgId); // returns null in oss const subscribed = await isSubscribed(orgId);
if (tier === TierId.STANDARD) { if (subscribed) {
loginPage = await getOrgLoginPage(orgId); loginPage = await getOrgLoginPage(orgId);
} }
} }

View File

@@ -14,8 +14,7 @@ import jsonwebtoken from "jsonwebtoken";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { decrypt } from "@server/lib/crypto"; import { decrypt } from "@server/lib/crypto";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing"; import { isSubscribed } from "@server/private/lib/isSubscribed";
import { TierId } from "@server/lib/billing/tiers";
const paramsSchema = z const paramsSchema = z
.object({ .object({
@@ -113,8 +112,7 @@ export async function generateOidcUrl(
} }
if (build === "saas") { if (build === "saas") {
const { tier } = await getOrgTierData(orgId); const subscribed = await isSubscribed(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) { if (!subscribed) {
return next( return next(
createHttpError( createHttpError(

View File

@@ -10,10 +10,10 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { cache } from "@server/lib/cache"; import { cache } from "@server/lib/cache";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { subscribe } from "node:diagnostics_channel";
import { isSubscribed } from "@server/private/lib/isSubscribed";
const updateOrgParamsSchema = z.strictObject({ const updateOrgParamsSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -95,10 +95,10 @@ export async function updateOrg(
parsedBody.data.passwordExpiryDays = undefined; parsedBody.data.passwordExpiryDays = undefined;
} }
const { tier } = await getOrgTierData(orgId); const subscribed = await isSubscribed(orgId);
if ( if (
build == "saas" && build == "saas" &&
tier != TierId.STANDARD && subscribed &&
parsedBody.data.settingsLogRetentionDaysRequest && parsedBody.data.settingsLogRetentionDaysRequest &&
parsedBody.data.settingsLogRetentionDaysRequest > 30 parsedBody.data.settingsLogRetentionDaysRequest > 30
) { ) {

View File

@@ -13,9 +13,8 @@ import { generateId } from "@server/auth/sessions/app";
import { usageService } from "@server/lib/billing/usageService"; import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing"; import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build"; import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "@server/private/lib/isSubscribed";
const paramsSchema = z.strictObject({ const paramsSchema = z.strictObject({
orgId: z.string().nonempty() orgId: z.string().nonempty()
@@ -132,9 +131,8 @@ export async function createOrgUser(
); );
} else if (type === "oidc") { } else if (type === "oidc") {
if (build === "saas") { if (build === "saas") {
const { tier } = await getOrgTierData(orgId); const subscribed = await isSubscribed(orgId);
const subscribed = tier === TierId.STANDARD; if (subscribed) {
if (!subscribed) {
return next( return next(
createHttpError( createHttpError(
HttpCode.FORBIDDEN, HttpCode.FORBIDDEN,

View File

@@ -76,7 +76,6 @@ export default function GeneralPage() {
setAllSubscriptions(subscriptions); setAllSubscriptions(subscriptions);
// Import tier and license price sets // Import tier and license price sets
const { getTierPriceSet } = await import("@server/lib/billing/tiers");
const { getLicensePriceSet } = await import("@server/lib/billing/licenses"); const { getLicensePriceSet } = await import("@server/lib/billing/licenses");
const tierPriceSet = getTierPriceSet( const tierPriceSet = getTierPriceSet(

View File

@@ -48,7 +48,6 @@ import { useTranslations } from "next-intl";
import { build } from "@server/build"; import { build } from "@server/build";
import Image from "next/image"; import Image from "next/image";
import { useSubscriptionStatusContext } from "@app/hooks/useSubscriptionStatusContext"; import { useSubscriptionStatusContext } from "@app/hooks/useSubscriptionStatusContext";
import { TierId } from "@server/lib/billing/tiers";
type UserType = "internal" | "oidc"; type UserType = "internal" | "oidc";

View File

@@ -23,8 +23,6 @@ import type {
LoadLoginPageBrandingResponse, LoadLoginPageBrandingResponse,
LoadLoginPageResponse LoadLoginPageResponse
} from "@server/routers/loginPage/types"; } from "@server/routers/loginPage/types";
import { GetOrgTierResponse } from "@server/routers/billing/types";
import { TierId } from "@server/lib/billing/tiers";
import { CheckOrgUserAccessResponse } from "@server/routers/org"; import { CheckOrgUserAccessResponse } from "@server/routers/org";
import OrgPolicyRequired from "@app/components/OrgPolicyRequired"; import OrgPolicyRequired from "@app/components/OrgPolicyRequired";
import { isOrgSubscribed } from "@app/lib/api/isOrgSubscribed"; import { isOrgSubscribed } from "@app/lib/api/isOrgSubscribed";

View File

@@ -5,7 +5,7 @@ type SubscriptionStatusContextType = {
subscriptionStatus: GetOrgSubscriptionResponse | null; subscriptionStatus: GetOrgSubscriptionResponse | null;
updateSubscriptionStatus: (updatedSite: GetOrgSubscriptionResponse) => void; updateSubscriptionStatus: (updatedSite: GetOrgSubscriptionResponse) => void;
isActive: () => boolean; isActive: () => boolean;
getTier: () => string | null; getTier: () => { tier: string | null; active: boolean };
isSubscribed: () => boolean; isSubscribed: () => boolean;
subscribed: boolean; subscribed: boolean;
}; };

View File

@@ -1,5 +1,4 @@
import { build } from "@server/build"; import { build } from "@server/build";
import { TierId } from "@server/lib/billing/tiers";
import { cache } from "react"; import { cache } from "react";
import { getCachedSubscription } from "./getCachedSubscription"; import { getCachedSubscription } from "./getCachedSubscription";
import { priv } from "."; import { priv } from ".";
@@ -21,7 +20,7 @@ export const isOrgSubscribed = cache(async (orgId: string) => {
try { try {
const subRes = await getCachedSubscription(orgId); const subRes = await getCachedSubscription(orgId);
subscribed = subscribed =
subRes.data.data.tier === TierId.STANDARD && (subRes.data.data.tier == "home_lab" || subRes.data.data.tier == "starter" || subRes.data.data.tier == "scale") &&
subRes.data.data.active; subRes.data.data.active;
} catch {} } catch {}
} }

View File

@@ -1,7 +1,6 @@
"use client"; "use client";
import SubscriptionStatusContext from "@app/contexts/subscriptionStatusContext"; import SubscriptionStatusContext from "@app/contexts/subscriptionStatusContext";
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
import { GetOrgSubscriptionResponse } from "@server/routers/billing/types"; import { GetOrgSubscriptionResponse } from "@server/routers/billing/types";
import { useState } from "react"; import { useState } from "react";
import { build } from "@server/build"; import { build } from "@server/build";
@@ -43,34 +42,37 @@ export function SubscriptionStatusProvider({
}; };
const getTier = () => { const getTier = () => {
const tierPriceSet = getTierPriceSet(env, sandbox_mode);
if (subscriptionStatus?.subscriptions) { if (subscriptionStatus?.subscriptions) {
// Iterate through all subscriptions // Iterate through all subscriptions
for (const { subscription, items } of subscriptionStatus.subscriptions) { for (const { subscription } of subscriptionStatus.subscriptions) {
if (items && items.length > 0) { if (
// Iterate through tiers in order (earlier keys are higher tiers) subscription.type == "home_lab" ||
for (const [tierId, priceId] of Object.entries(tierPriceSet)) { subscription.type == "starter" ||
// Check if any subscription item matches this tier's price ID subscription.type == "scale"
const matchingItem = items.find( ) {
(item) => item.priceId === priceId return {
); tier: subscription.type,
if (matchingItem) { active: subscription.status === "active"
return tierId; };
}
}
} }
} }
} }
return null; return {
tier: null,
active: false
};
}; };
const isSubscribed = () => { const isSubscribed = () => {
if (build === "enterprise") { if (build === "enterprise") {
return true; return true;
} }
return getTier() === TierId.STANDARD; const { tier, active } = getTier();
return (
(tier == "home_lab" || tier == "starter" || tier == "scale") &&
active
);
}; };
const [subscribed, setSubscribed] = useState<boolean>(isSubscribed()); const [subscribed, setSubscribed] = useState<boolean>(isSubscribed());
@@ -91,4 +93,4 @@ export function SubscriptionStatusProvider({
); );
} }
export default SubscriptionStatusProvider; export default SubscriptionStatusProvider;