Adding limit checks

This commit is contained in:
Owen
2026-02-10 10:05:37 -08:00
parent d814ad9f3e
commit 193b7ff21e
10 changed files with 207 additions and 64 deletions

View File

@@ -1,6 +1,3 @@
import Stripe from "stripe";
import { usageService } from "./usageService";
export enum FeatureId {
USERS = "users",
SITES = "sites",
@@ -135,25 +132,3 @@ export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined {
return undefined;
}
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -0,0 +1,25 @@
import Stripe from "stripe";
import { FeatureId, FeaturePriceSet } from "./features";
import { usageService } from "./usageService";
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -29,3 +29,4 @@ export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess";
export * from "./logActionAudit";
export * from "./verifyOlmAccess";
export * from "./verifyLimits";

View File

@@ -4,7 +4,6 @@ import { apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
export async function verifyApiKeyOrgAccess(
req: Request,

View File

@@ -0,0 +1,47 @@
import { Request, Response, NextFunction } from "express";
import { db, orgs } from "@server/db";
import { userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
export async function verifyLimits(
req: Request,
res: Response,
next: NextFunction
) {
if (build != "saas") {
return next();
}
const orgId = req.userOrgId || req.params.orgId;
if (!orgId) {
return next(); // its fine if we silently fail here because this is not critical to operation or security and its better user experience if we dont fail
}
try {
const reject = await usageService.checkLimitSet(orgId);
if (reject) {
return next(
createHttpError(
HttpCode.PAYMENT_REQUIRED,
"Organization has exceeded its usage limits. Please upgrade your plan or contact support."
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking limits"
)
);
}
}

View File

@@ -25,10 +25,10 @@ import {
getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet,
getStarterFeaturePriceSet,
getLineItems,
FeatureId,
type FeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
const changeTierSchema = z.strictObject({
orgId: z.string()
@@ -151,8 +151,10 @@ export async function changeTier(
// tier1 uses TIER1 product, tier2/tier3 use USERS product
const currentTier = subscription.type;
const switchingProducts =
(currentTier === "tier1" && (tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") && tier === "tier1");
(currentTier === "tier1" &&
(tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") &&
tier === "tier1");
let updatedSubscription;

View File

@@ -22,8 +22,12 @@ import logger from "@server/logger";
import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
import { getHomeLabFeaturePriceSet, getLineItems, getScaleFeaturePriceSet, getStarterFeaturePriceSet } from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import {
getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet,
getStarterFeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
import Stripe from "stripe";
const createCheckoutSessionSchema = z.strictObject({
@@ -31,7 +35,7 @@ const createCheckoutSessionSchema = z.strictObject({
});
const createCheckoutSessionBodySchema = z.strictObject({
tier: z.enum(["tier1", "tier2", "tier3"]),
tier: z.enum(["tier1", "tier2", "tier3"])
});
export async function createCheckoutSession(
@@ -90,12 +94,10 @@ export async function createCheckoutSession(
} else if (tier === "tier3") {
lineItems = await getLineItems(getScaleFeaturePriceSet(), orgId);
} else {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid plan")
);
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan"));
}
logger.debug(`Line items: ${JSON.stringify(lineItems)}`)
logger.debug(`Line items: ${JSON.stringify(lineItems)}`);
const session = await stripe!.checkout.sessions.create({
client_reference_id: orgId, // So we can look it up the org later on the webhook

View File

@@ -41,7 +41,8 @@ import {
verifyUserHasAction,
verifyUserIsOrgOwner,
verifySiteResourceAccess,
verifyOlmAccess
verifyOlmAccess,
verifyLimits
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
@@ -79,6 +80,7 @@ authenticated.get(
authenticated.post(
"/org/:orgId",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateOrg),
logActionAudit(ActionsEnum.updateOrg),
org.updateOrg
@@ -161,6 +163,7 @@ authenticated.get(
authenticated.put(
"/org/:orgId/client",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient),
client.createClient
@@ -178,6 +181,7 @@ authenticated.delete(
authenticated.post(
"/client/:clientId/archive",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient),
client.archiveClient
@@ -186,6 +190,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId/unarchive",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient
@@ -194,6 +199,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId/block",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient),
client.blockClient
@@ -202,6 +208,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId/unblock",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient),
client.unblockClient
@@ -210,6 +217,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId",
verifyClientAccess, // this will check if the user has access to the client
verifyLimits,
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
logActionAudit(ActionsEnum.updateClient),
client.updateClient
@@ -224,6 +232,7 @@ authenticated.post(
authenticated.post(
"/site/:siteId",
verifySiteAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSite),
logActionAudit(ActionsEnum.updateSite),
site.updateSite
@@ -273,6 +282,7 @@ authenticated.get(
authenticated.put(
"/org/:orgId/site-resource",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource),
siteResource.createSiteResource
@@ -303,6 +313,7 @@ authenticated.get(
authenticated.post(
"/site-resource/:siteResourceId",
verifySiteResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSiteResource),
logActionAudit(ActionsEnum.updateSiteResource),
siteResource.updateSiteResource
@@ -341,6 +352,7 @@ authenticated.post(
"/site-resource/:siteResourceId/roles",
verifySiteResourceAccess,
verifyRoleAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles
@@ -350,6 +362,7 @@ authenticated.post(
"/site-resource/:siteResourceId/users",
verifySiteResourceAccess,
verifySetResourceUsers,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers
@@ -359,6 +372,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients",
verifySiteResourceAccess,
verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients
@@ -368,6 +382,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/add",
verifySiteResourceAccess,
verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource
@@ -377,6 +392,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/remove",
verifySiteResourceAccess,
verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource
@@ -385,6 +401,7 @@ authenticated.post(
authenticated.put(
"/org/:orgId/resource",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResource),
logActionAudit(ActionsEnum.createResource),
resource.createResource
@@ -499,6 +516,7 @@ authenticated.get(
authenticated.post(
"/resource/:resourceId",
verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateResource),
logActionAudit(ActionsEnum.updateResource),
resource.updateResource
@@ -514,6 +532,7 @@ authenticated.delete(
authenticated.put(
"/resource/:resourceId/target",
verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createTarget),
logActionAudit(ActionsEnum.createTarget),
target.createTarget
@@ -528,6 +547,7 @@ authenticated.get(
authenticated.put(
"/resource/:resourceId/rule",
verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResourceRule),
logActionAudit(ActionsEnum.createResourceRule),
resource.createResourceRule
@@ -577,6 +597,7 @@ authenticated.delete(
authenticated.put(
"/org/:orgId/role",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createRole),
logActionAudit(ActionsEnum.createRole),
role.createRole
@@ -774,6 +795,7 @@ authenticated.delete(
authenticated.put(
"/org/:orgId/user",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgUser),
logActionAudit(ActionsEnum.createOrgUser),
user.createOrgUser
@@ -985,6 +1007,7 @@ authenticated.get(
authenticated.put(
`/org/:orgId/api-key`,
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createApiKey),
logActionAudit(ActionsEnum.createApiKey),
apiKeys.createOrgApiKey
@@ -1010,6 +1033,7 @@ authenticated.get(
authenticated.put(
`/org/:orgId/domain`,
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgDomain),
logActionAudit(ActionsEnum.createOrgDomain),
domain.createOrgDomain
@@ -1065,6 +1089,7 @@ authenticated.get(
authenticated.put(
"/org/:orgId/blueprint",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.applyBlueprint),
blueprints.applyYAMLBlueprint
);

View File

@@ -146,7 +146,7 @@ authenticated.get(
);
// Site Resource endpoints
authenticated.put(
"/org/:orgId/private-resource",
"/org/:orgId/site-resource",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource),