Adding limit checks

This commit is contained in:
Owen
2026-02-10 10:05:37 -08:00
parent 9d0ff472e5
commit f318f6304b
10 changed files with 207 additions and 64 deletions

View File

@@ -1,6 +1,3 @@
import Stripe from "stripe";
import { usageService } from "./usageService";
export enum FeatureId {
USERS = "users",
SITES = "sites",
@@ -135,25 +132,3 @@ export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined {
return undefined;
}
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -0,0 +1,25 @@
import Stripe from "stripe";
import { FeatureId, FeaturePriceSet } from "./features";
import { usageService } from "./usageService";
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -29,3 +29,4 @@ export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess";
export * from "./logActionAudit";
export * from "./verifyOlmAccess";
export * from "./verifyLimits";

View File

@@ -4,7 +4,6 @@ import { apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
export async function verifyApiKeyOrgAccess(
req: Request,

View File

@@ -0,0 +1,47 @@
import { Request, Response, NextFunction } from "express";
import { db, orgs } from "@server/db";
import { userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
export async function verifyLimits(
req: Request,
res: Response,
next: NextFunction
) {
if (build != "saas") {
return next();
}
const orgId = req.userOrgId || req.params.orgId;
if (!orgId) {
return next(); // its fine if we silently fail here because this is not critical to operation or security and its better user experience if we dont fail
}
try {
const reject = await usageService.checkLimitSet(orgId);
if (reject) {
return next(
createHttpError(
HttpCode.PAYMENT_REQUIRED,
"Organization has exceeded its usage limits. Please upgrade your plan or contact support."
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking limits"
)
);
}
}

View File

@@ -25,10 +25,10 @@ import {
getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet,
getStarterFeaturePriceSet,
getLineItems,
FeatureId,
type FeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
const changeTierSchema = z.strictObject({
orgId: z.string()
@@ -151,8 +151,10 @@ export async function changeTier(
// tier1 uses TIER1 product, tier2/tier3 use USERS product
const currentTier = subscription.type;
const switchingProducts =
(currentTier === "tier1" && (tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") && tier === "tier1");
(currentTier === "tier1" &&
(tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") &&
tier === "tier1");
let updatedSubscription;

View File

@@ -22,8 +22,12 @@ import logger from "@server/logger";
import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
import { getHomeLabFeaturePriceSet, getLineItems, getScaleFeaturePriceSet, getStarterFeaturePriceSet } from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import {
getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet,
getStarterFeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
import Stripe from "stripe";
const createCheckoutSessionSchema = z.strictObject({
@@ -31,7 +35,7 @@ const createCheckoutSessionSchema = z.strictObject({
});
const createCheckoutSessionBodySchema = z.strictObject({
tier: z.enum(["tier1", "tier2", "tier3"]),
tier: z.enum(["tier1", "tier2", "tier3"])
});
export async function createCheckoutSession(
@@ -90,12 +94,10 @@ export async function createCheckoutSession(
} else if (tier === "tier3") {
lineItems = await getLineItems(getScaleFeaturePriceSet(), orgId);
} else {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Invalid plan")
);
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan"));
}
logger.debug(`Line items: ${JSON.stringify(lineItems)}`)
logger.debug(`Line items: ${JSON.stringify(lineItems)}`);
const session = await stripe!.checkout.sessions.create({
client_reference_id: orgId, // So we can look it up the org later on the webhook

View File

@@ -41,7 +41,8 @@ import {
verifyUserHasAction,
verifyUserIsOrgOwner,
verifySiteResourceAccess,
verifyOlmAccess
verifyOlmAccess,
verifyLimits
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
@@ -79,6 +80,7 @@ authenticated.get(
authenticated.post(
"/org/:orgId",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateOrg),
logActionAudit(ActionsEnum.updateOrg),
org.updateOrg
@@ -161,6 +163,7 @@ authenticated.get(
authenticated.put(
"/org/:orgId/client",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient),
client.createClient
@@ -178,6 +181,7 @@ authenticated.delete(
authenticated.post(
"/client/:clientId/archive",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient),
client.archiveClient
@@ -186,6 +190,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId/unarchive",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient
@@ -194,6 +199,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId/block",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient),
client.blockClient
@@ -202,6 +208,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId/unblock",
verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient),
client.unblockClient
@@ -210,6 +217,7 @@ authenticated.post(
authenticated.post(
"/client/:clientId",
verifyClientAccess, // this will check if the user has access to the client
verifyLimits,
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
logActionAudit(ActionsEnum.updateClient),
client.updateClient
@@ -224,6 +232,7 @@ authenticated.post(
authenticated.post(
"/site/:siteId",
verifySiteAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSite),
logActionAudit(ActionsEnum.updateSite),
site.updateSite
@@ -273,6 +282,7 @@ authenticated.get(
authenticated.put(
"/org/:orgId/site-resource",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource),
siteResource.createSiteResource
@@ -303,6 +313,7 @@ authenticated.get(
authenticated.post(
"/site-resource/:siteResourceId",
verifySiteResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSiteResource),
logActionAudit(ActionsEnum.updateSiteResource),
siteResource.updateSiteResource
@@ -341,6 +352,7 @@ authenticated.post(
"/site-resource/:siteResourceId/roles",
verifySiteResourceAccess,
verifyRoleAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles
@@ -350,6 +362,7 @@ authenticated.post(
"/site-resource/:siteResourceId/users",
verifySiteResourceAccess,
verifySetResourceUsers,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers
@@ -359,6 +372,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients",
verifySiteResourceAccess,
verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients
@@ -368,6 +382,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/add",
verifySiteResourceAccess,
verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource
@@ -377,6 +392,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/remove",
verifySiteResourceAccess,
verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource
@@ -385,6 +401,7 @@ authenticated.post(
authenticated.put(
"/org/:orgId/resource",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResource),
logActionAudit(ActionsEnum.createResource),
resource.createResource
@@ -499,6 +516,7 @@ authenticated.get(
authenticated.post(
"/resource/:resourceId",
verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateResource),
logActionAudit(ActionsEnum.updateResource),
resource.updateResource
@@ -514,6 +532,7 @@ authenticated.delete(
authenticated.put(
"/resource/:resourceId/target",
verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createTarget),
logActionAudit(ActionsEnum.createTarget),
target.createTarget
@@ -528,6 +547,7 @@ authenticated.get(
authenticated.put(
"/resource/:resourceId/rule",
verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResourceRule),
logActionAudit(ActionsEnum.createResourceRule),
resource.createResourceRule
@@ -577,6 +597,7 @@ authenticated.delete(
authenticated.put(
"/org/:orgId/role",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createRole),
logActionAudit(ActionsEnum.createRole),
role.createRole
@@ -774,6 +795,7 @@ authenticated.delete(
authenticated.put(
"/org/:orgId/user",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgUser),
logActionAudit(ActionsEnum.createOrgUser),
user.createOrgUser
@@ -985,6 +1007,7 @@ authenticated.get(
authenticated.put(
`/org/:orgId/api-key`,
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createApiKey),
logActionAudit(ActionsEnum.createApiKey),
apiKeys.createOrgApiKey
@@ -1010,6 +1033,7 @@ authenticated.get(
authenticated.put(
`/org/:orgId/domain`,
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgDomain),
logActionAudit(ActionsEnum.createOrgDomain),
domain.createOrgDomain
@@ -1065,6 +1089,7 @@ authenticated.get(
authenticated.put(
"/org/:orgId/blueprint",
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.applyBlueprint),
blueprints.applyYAMLBlueprint
);

View File

@@ -146,7 +146,7 @@ authenticated.get(
);
// Site Resource endpoints
authenticated.put(
"/org/:orgId/private-resource",
"/org/:orgId/site-resource",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource),

View File

@@ -42,7 +42,11 @@ import {
import { useTranslations } from "use-intl";
import Link from "next/link";
import { Tier } from "@server/types/Tiers";
import { tier1LimitSet, tier2LimitSet, tier3LimitSet } from "@server/lib/billing/limitSet";
import {
tier1LimitSet,
tier2LimitSet,
tier3LimitSet
} from "@server/lib/billing/limitSet";
import { FeatureId } from "@server/lib/billing/features";
// Plan tier definitions matching the mockup
@@ -93,7 +97,10 @@ const planOptions: PlanOption[] = [
];
// Tier limits mapping derived from limit sets
const tierLimits: Record<Tier, { users: number; sites: number; domains: number; remoteNodes: number }> = {
const tierLimits: Record<
Tier,
{ users: number; sites: number; domains: number; remoteNodes: number }
> = {
tier1: {
users: tier1LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0,
@@ -151,8 +158,8 @@ export default function BillingPage() {
const [currentTier, setCurrentTier] = useState<Tier | null>(null);
// Usage IDs
const SITES = "sites";
const USERS = "users";
const SITES = "sites";
const DOMAINS = "domains";
const REMOTE_EXIT_NODES = "remoteExitNodes";
@@ -297,8 +304,70 @@ export default function BillingPage() {
await api.post(`/org/${org.org.orgId}/billing/change-tier`, {
tier
});
// Refresh subscription data
window.location.reload();
// Poll the API to check if the tier change has been reflected
const pollForTierChange = async (targetTier: Tier) => {
const maxAttempts = 30; // 30 seconds with 1 second interval
let attempts = 0;
const poll = async (): Promise<boolean> => {
try {
const res = await api.get<
AxiosResponse<GetOrgSubscriptionResponse>
>(`/org/${org.org.orgId}/billing/subscriptions`);
const { subscriptions } = res.data.data;
// Find tier subscription
const tierSub = subscriptions.find(
({ subscription }) =>
subscription?.type === "tier1" ||
subscription?.type === "tier2" ||
subscription?.type === "tier3"
);
// Check if the tier has changed to the target tier
if (tierSub?.subscription?.type === targetTier) {
return true;
}
return false;
} catch (error) {
console.error("Error polling subscription:", error);
return false;
}
};
while (attempts < maxAttempts) {
const success = await poll();
if (success) {
// Tier change reflected, refresh the page
window.location.reload();
return;
}
attempts++;
if (attempts < maxAttempts) {
// Wait 1 second before next poll
await new Promise((resolve) =>
setTimeout(resolve, 1000)
);
}
}
// If we've exhausted all attempts, show an error
toast({
title: "Tier change processing",
description:
"Your tier change is taking longer than expected. Please refresh the page in a moment to see the changes.",
variant: "destructive"
});
setIsLoading(false);
};
// Start polling for the tier change
pollForTierChange(tier);
} catch (error) {
toast({
title: "Failed to change tier",
@@ -323,8 +392,8 @@ export default function BillingPage() {
}
}
setShowConfirmDialog(false);
setPendingTier(null);
// setShowConfirmDialog(false);
// setPendingTier(null);
};
const showTierConfirmation = (
@@ -338,7 +407,7 @@ export default function BillingPage() {
};
const handleContactUs = () => {
window.open("mailto:sales@pangolin.net", "_blank");
window.open("https://pangolin.net/talk-to-us", "_blank");
};
// Get current plan ID from tier
@@ -393,7 +462,7 @@ export default function BillingPage() {
plan.tierType,
"downgrade",
plan.name,
plan.price + (plan.priceDetail || "")
plan.price + (" " + plan.priceDetail || "")
);
} else {
handleModifySubscription();
@@ -412,7 +481,7 @@ export default function BillingPage() {
plan.tierType,
"upgrade",
plan.name,
plan.price + (plan.priceDetail || "")
plan.price + (" " + plan.priceDetail || "")
);
} else {
handleModifySubscription();
@@ -438,17 +507,15 @@ export default function BillingPage() {
// Calculate current usage cost for display
const getUserCount = () => getUsageValue(USERS);
const getPricePerUser = () => {
console.log(
"Calculating price per user, tierSubscription:",
tierSubscription
);
if (!tierSubscription?.items) return 0;
// Find the subscription item for USERS feature
const usersItem = tierSubscription.items.find(
(item) => item.planId === USERS
(item) => item.featureId === USERS
);
console.log("Users subscription item:", usersItem);
// unitAmount is in cents, convert to dollars
if (usersItem?.unitAmount) {
return usersItem.unitAmount / 100;
@@ -529,6 +596,7 @@ export default function BillingPage() {
disabled={
isLoading || planAction.disabled
}
loading={isLoading && isCurrentPlan}
>
{planAction.label}
</Button>
@@ -736,9 +804,9 @@ export default function BillingPage() {
<span>
{
tierLimits[pendingTier.tier]
.sites
.users
}{" "}
{t("billingSites") || "Sites"}
{t("billingUsers") || "Users"}
</span>
</div>
<div className="flex items-center gap-2">
@@ -746,9 +814,9 @@ export default function BillingPage() {
<span>
{
tierLimits[pendingTier.tier]
.users
.sites
}{" "}
{t("billingUsers") || "Users"}
{t("billingSites") || "Sites"}
</span>
</div>
<div className="flex items-center gap-2">
@@ -787,14 +855,13 @@ export default function BillingPage() {
<Button
onClick={confirmTierChange}
disabled={isLoading}
loading={isLoading}
>
{isLoading
? t("billingProcessing") || "Processing..."
: pendingTier?.action === "upgrade"
? t("billingConfirmUpgradeButton") ||
"Confirm Upgrade"
: t("billingConfirmDowngradeButton") ||
"Confirm Downgrade"}
{pendingTier?.action === "upgrade"
? t("billingConfirmUpgradeButton") ||
"Confirm Upgrade"
: t("billingConfirmDowngradeButton") ||
"Confirm Downgrade"}
</Button>
</CredenzaFooter>
</CredenzaContent>