Adding limit checks

This commit is contained in:
Owen
2026-02-10 10:05:37 -08:00
committed by Owen Schwartz
parent c33f782eaf
commit 84d1cf7a1d
10 changed files with 207 additions and 64 deletions

View File

@@ -1,6 +1,3 @@
import Stripe from "stripe";
import { usageService } from "./usageService";
export enum FeatureId { export enum FeatureId {
USERS = "users", USERS = "users",
SITES = "sites", SITES = "sites",
@@ -135,25 +132,3 @@ export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined {
return undefined; return undefined;
} }
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -0,0 +1,25 @@
import Stripe from "stripe";
import { FeatureId, FeaturePriceSet } from "./features";
import { usageService } from "./usageService";
export async function getLineItems(
featurePriceSet: FeaturePriceSet,
orgId: string,
): Promise<Stripe.Checkout.SessionCreateParams.LineItem[]> {
const users = await usageService.getUsage(orgId, FeatureId.USERS);
return Object.entries(featurePriceSet).map(([featureId, priceId]) => {
let quantity: number | undefined;
if (featureId === FeatureId.USERS) {
quantity = users?.instantaneousValue || 1;
} else if (featureId === FeatureId.TIER1) {
quantity = 1;
}
return {
price: priceId,
quantity: quantity
};
});
}

View File

@@ -29,3 +29,4 @@ export * from "./verifyUserIsOrgOwner";
export * from "./verifySiteResourceAccess"; export * from "./verifySiteResourceAccess";
export * from "./logActionAudit"; export * from "./logActionAudit";
export * from "./verifyOlmAccess"; export * from "./verifyOlmAccess";
export * from "./verifyLimits";

View File

@@ -4,7 +4,6 @@ import { apiKeyOrg } from "@server/db";
import { and, eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import logger from "@server/logger";
export async function verifyApiKeyOrgAccess( export async function verifyApiKeyOrgAccess(
req: Request, req: Request,

View File

@@ -0,0 +1,47 @@
import { Request, Response, NextFunction } from "express";
import { db, orgs } from "@server/db";
import { userOrgs } from "@server/db";
import { and, eq } from "drizzle-orm";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
export async function verifyLimits(
req: Request,
res: Response,
next: NextFunction
) {
if (build != "saas") {
return next();
}
const orgId = req.userOrgId || req.params.orgId;
if (!orgId) {
return next(); // its fine if we silently fail here because this is not critical to operation or security and its better user experience if we dont fail
}
try {
const reject = await usageService.checkLimitSet(orgId);
if (reject) {
return next(
createHttpError(
HttpCode.PAYMENT_REQUIRED,
"Organization has exceeded its usage limits. Please upgrade your plan or contact support."
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error checking limits"
)
);
}
}

View File

@@ -25,10 +25,10 @@ import {
getHomeLabFeaturePriceSet, getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet, getScaleFeaturePriceSet,
getStarterFeaturePriceSet, getStarterFeaturePriceSet,
getLineItems,
FeatureId, FeatureId,
type FeaturePriceSet type FeaturePriceSet
} from "@server/lib/billing"; } from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
const changeTierSchema = z.strictObject({ const changeTierSchema = z.strictObject({
orgId: z.string() orgId: z.string()
@@ -151,8 +151,10 @@ export async function changeTier(
// tier1 uses TIER1 product, tier2/tier3 use USERS product // tier1 uses TIER1 product, tier2/tier3 use USERS product
const currentTier = subscription.type; const currentTier = subscription.type;
const switchingProducts = const switchingProducts =
(currentTier === "tier1" && (tier === "tier2" || tier === "tier3")) || (currentTier === "tier1" &&
((currentTier === "tier2" || currentTier === "tier3") && tier === "tier1"); (tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") &&
tier === "tier1");
let updatedSubscription; let updatedSubscription;

View File

@@ -22,8 +22,12 @@ import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe"; import stripe from "#private/lib/stripe";
import { getHomeLabFeaturePriceSet, getLineItems, getScaleFeaturePriceSet, getStarterFeaturePriceSet } from "@server/lib/billing"; import {
import { usageService } from "@server/lib/billing/usageService"; getHomeLabFeaturePriceSet,
getScaleFeaturePriceSet,
getStarterFeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
import Stripe from "stripe"; import Stripe from "stripe";
const createCheckoutSessionSchema = z.strictObject({ const createCheckoutSessionSchema = z.strictObject({
@@ -31,7 +35,7 @@ const createCheckoutSessionSchema = z.strictObject({
}); });
const createCheckoutSessionBodySchema = z.strictObject({ const createCheckoutSessionBodySchema = z.strictObject({
tier: z.enum(["tier1", "tier2", "tier3"]), tier: z.enum(["tier1", "tier2", "tier3"])
}); });
export async function createCheckoutSession( export async function createCheckoutSession(
@@ -90,12 +94,10 @@ export async function createCheckoutSession(
} else if (tier === "tier3") { } else if (tier === "tier3") {
lineItems = await getLineItems(getScaleFeaturePriceSet(), orgId); lineItems = await getLineItems(getScaleFeaturePriceSet(), orgId);
} else { } else {
return next( return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan"));
createHttpError(HttpCode.BAD_REQUEST, "Invalid plan")
);
} }
logger.debug(`Line items: ${JSON.stringify(lineItems)}`) logger.debug(`Line items: ${JSON.stringify(lineItems)}`);
const session = await stripe!.checkout.sessions.create({ const session = await stripe!.checkout.sessions.create({
client_reference_id: orgId, // So we can look it up the org later on the webhook client_reference_id: orgId, // So we can look it up the org later on the webhook

View File

@@ -41,7 +41,8 @@ import {
verifyUserHasAction, verifyUserHasAction,
verifyUserIsOrgOwner, verifyUserIsOrgOwner,
verifySiteResourceAccess, verifySiteResourceAccess,
verifyOlmAccess verifyOlmAccess,
verifyLimits
} from "@server/middlewares"; } from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions"; import { ActionsEnum } from "@server/auth/actions";
import rateLimit, { ipKeyGenerator } from "express-rate-limit"; import rateLimit, { ipKeyGenerator } from "express-rate-limit";
@@ -79,6 +80,7 @@ authenticated.get(
authenticated.post( authenticated.post(
"/org/:orgId", "/org/:orgId",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateOrg), verifyUserHasAction(ActionsEnum.updateOrg),
logActionAudit(ActionsEnum.updateOrg), logActionAudit(ActionsEnum.updateOrg),
org.updateOrg org.updateOrg
@@ -161,6 +163,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/client", "/org/:orgId/client",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createClient), verifyUserHasAction(ActionsEnum.createClient),
logActionAudit(ActionsEnum.createClient), logActionAudit(ActionsEnum.createClient),
client.createClient client.createClient
@@ -178,6 +181,7 @@ authenticated.delete(
authenticated.post( authenticated.post(
"/client/:clientId/archive", "/client/:clientId/archive",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.archiveClient), verifyUserHasAction(ActionsEnum.archiveClient),
logActionAudit(ActionsEnum.archiveClient), logActionAudit(ActionsEnum.archiveClient),
client.archiveClient client.archiveClient
@@ -186,6 +190,7 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/unarchive", "/client/:clientId/unarchive",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unarchiveClient), verifyUserHasAction(ActionsEnum.unarchiveClient),
logActionAudit(ActionsEnum.unarchiveClient), logActionAudit(ActionsEnum.unarchiveClient),
client.unarchiveClient client.unarchiveClient
@@ -194,6 +199,7 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/block", "/client/:clientId/block",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.blockClient), verifyUserHasAction(ActionsEnum.blockClient),
logActionAudit(ActionsEnum.blockClient), logActionAudit(ActionsEnum.blockClient),
client.blockClient client.blockClient
@@ -202,6 +208,7 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId/unblock", "/client/:clientId/unblock",
verifyClientAccess, verifyClientAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.unblockClient), verifyUserHasAction(ActionsEnum.unblockClient),
logActionAudit(ActionsEnum.unblockClient), logActionAudit(ActionsEnum.unblockClient),
client.unblockClient client.unblockClient
@@ -210,6 +217,7 @@ authenticated.post(
authenticated.post( authenticated.post(
"/client/:clientId", "/client/:clientId",
verifyClientAccess, // this will check if the user has access to the client verifyClientAccess, // this will check if the user has access to the client
verifyLimits,
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
logActionAudit(ActionsEnum.updateClient), logActionAudit(ActionsEnum.updateClient),
client.updateClient client.updateClient
@@ -224,6 +232,7 @@ authenticated.post(
authenticated.post( authenticated.post(
"/site/:siteId", "/site/:siteId",
verifySiteAccess, verifySiteAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSite), verifyUserHasAction(ActionsEnum.updateSite),
logActionAudit(ActionsEnum.updateSite), logActionAudit(ActionsEnum.updateSite),
site.updateSite site.updateSite
@@ -273,6 +282,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/site-resource", "/org/:orgId/site-resource",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createSiteResource), verifyUserHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource), logActionAudit(ActionsEnum.createSiteResource),
siteResource.createSiteResource siteResource.createSiteResource
@@ -303,6 +313,7 @@ authenticated.get(
authenticated.post( authenticated.post(
"/site-resource/:siteResourceId", "/site-resource/:siteResourceId",
verifySiteResourceAccess, verifySiteResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateSiteResource), verifyUserHasAction(ActionsEnum.updateSiteResource),
logActionAudit(ActionsEnum.updateSiteResource), logActionAudit(ActionsEnum.updateSiteResource),
siteResource.updateSiteResource siteResource.updateSiteResource
@@ -341,6 +352,7 @@ authenticated.post(
"/site-resource/:siteResourceId/roles", "/site-resource/:siteResourceId/roles",
verifySiteResourceAccess, verifySiteResourceAccess,
verifyRoleAccess, verifyRoleAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceRoles), verifyUserHasAction(ActionsEnum.setResourceRoles),
logActionAudit(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles),
siteResource.setSiteResourceRoles siteResource.setSiteResourceRoles
@@ -350,6 +362,7 @@ authenticated.post(
"/site-resource/:siteResourceId/users", "/site-resource/:siteResourceId/users",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceUsers, verifySetResourceUsers,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceUsers siteResource.setSiteResourceUsers
@@ -359,6 +372,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients", "/site-resource/:siteResourceId/clients",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceClients, verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.setSiteResourceClients siteResource.setSiteResourceClients
@@ -368,6 +382,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/add", "/site-resource/:siteResourceId/clients/add",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceClients, verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.addClientToSiteResource siteResource.addClientToSiteResource
@@ -377,6 +392,7 @@ authenticated.post(
"/site-resource/:siteResourceId/clients/remove", "/site-resource/:siteResourceId/clients/remove",
verifySiteResourceAccess, verifySiteResourceAccess,
verifySetResourceClients, verifySetResourceClients,
verifyLimits,
verifyUserHasAction(ActionsEnum.setResourceUsers), verifyUserHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers),
siteResource.removeClientFromSiteResource siteResource.removeClientFromSiteResource
@@ -385,6 +401,7 @@ authenticated.post(
authenticated.put( authenticated.put(
"/org/:orgId/resource", "/org/:orgId/resource",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResource), verifyUserHasAction(ActionsEnum.createResource),
logActionAudit(ActionsEnum.createResource), logActionAudit(ActionsEnum.createResource),
resource.createResource resource.createResource
@@ -499,6 +516,7 @@ authenticated.get(
authenticated.post( authenticated.post(
"/resource/:resourceId", "/resource/:resourceId",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateResource), verifyUserHasAction(ActionsEnum.updateResource),
logActionAudit(ActionsEnum.updateResource), logActionAudit(ActionsEnum.updateResource),
resource.updateResource resource.updateResource
@@ -514,6 +532,7 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/resource/:resourceId/target", "/resource/:resourceId/target",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createTarget), verifyUserHasAction(ActionsEnum.createTarget),
logActionAudit(ActionsEnum.createTarget), logActionAudit(ActionsEnum.createTarget),
target.createTarget target.createTarget
@@ -528,6 +547,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/resource/:resourceId/rule", "/resource/:resourceId/rule",
verifyResourceAccess, verifyResourceAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createResourceRule), verifyUserHasAction(ActionsEnum.createResourceRule),
logActionAudit(ActionsEnum.createResourceRule), logActionAudit(ActionsEnum.createResourceRule),
resource.createResourceRule resource.createResourceRule
@@ -577,6 +597,7 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/role", "/org/:orgId/role",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createRole), verifyUserHasAction(ActionsEnum.createRole),
logActionAudit(ActionsEnum.createRole), logActionAudit(ActionsEnum.createRole),
role.createRole role.createRole
@@ -774,6 +795,7 @@ authenticated.delete(
authenticated.put( authenticated.put(
"/org/:orgId/user", "/org/:orgId/user",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgUser), verifyUserHasAction(ActionsEnum.createOrgUser),
logActionAudit(ActionsEnum.createOrgUser), logActionAudit(ActionsEnum.createOrgUser),
user.createOrgUser user.createOrgUser
@@ -985,6 +1007,7 @@ authenticated.get(
authenticated.put( authenticated.put(
`/org/:orgId/api-key`, `/org/:orgId/api-key`,
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createApiKey), verifyUserHasAction(ActionsEnum.createApiKey),
logActionAudit(ActionsEnum.createApiKey), logActionAudit(ActionsEnum.createApiKey),
apiKeys.createOrgApiKey apiKeys.createOrgApiKey
@@ -1010,6 +1033,7 @@ authenticated.get(
authenticated.put( authenticated.put(
`/org/:orgId/domain`, `/org/:orgId/domain`,
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createOrgDomain), verifyUserHasAction(ActionsEnum.createOrgDomain),
logActionAudit(ActionsEnum.createOrgDomain), logActionAudit(ActionsEnum.createOrgDomain),
domain.createOrgDomain domain.createOrgDomain
@@ -1065,6 +1089,7 @@ authenticated.get(
authenticated.put( authenticated.put(
"/org/:orgId/blueprint", "/org/:orgId/blueprint",
verifyOrgAccess, verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.applyBlueprint), verifyUserHasAction(ActionsEnum.applyBlueprint),
blueprints.applyYAMLBlueprint blueprints.applyYAMLBlueprint
); );

View File

@@ -146,7 +146,7 @@ authenticated.get(
); );
// Site Resource endpoints // Site Resource endpoints
authenticated.put( authenticated.put(
"/org/:orgId/private-resource", "/org/:orgId/site-resource",
verifyApiKeyOrgAccess, verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.createSiteResource), verifyApiKeyHasAction(ActionsEnum.createSiteResource),
logActionAudit(ActionsEnum.createSiteResource), logActionAudit(ActionsEnum.createSiteResource),

View File

@@ -42,7 +42,11 @@ import {
import { useTranslations } from "use-intl"; import { useTranslations } from "use-intl";
import Link from "next/link"; import Link from "next/link";
import { Tier } from "@server/types/Tiers"; import { Tier } from "@server/types/Tiers";
import { tier1LimitSet, tier2LimitSet, tier3LimitSet } from "@server/lib/billing/limitSet"; import {
tier1LimitSet,
tier2LimitSet,
tier3LimitSet
} from "@server/lib/billing/limitSet";
import { FeatureId } from "@server/lib/billing/features"; import { FeatureId } from "@server/lib/billing/features";
// Plan tier definitions matching the mockup // Plan tier definitions matching the mockup
@@ -93,7 +97,10 @@ const planOptions: PlanOption[] = [
]; ];
// Tier limits mapping derived from limit sets // Tier limits mapping derived from limit sets
const tierLimits: Record<Tier, { users: number; sites: number; domains: number; remoteNodes: number }> = { const tierLimits: Record<
Tier,
{ users: number; sites: number; domains: number; remoteNodes: number }
> = {
tier1: { tier1: {
users: tier1LimitSet[FeatureId.USERS]?.value ?? 0, users: tier1LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0, sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0,
@@ -151,8 +158,8 @@ export default function BillingPage() {
const [currentTier, setCurrentTier] = useState<Tier | null>(null); const [currentTier, setCurrentTier] = useState<Tier | null>(null);
// Usage IDs // Usage IDs
const SITES = "sites";
const USERS = "users"; const USERS = "users";
const SITES = "sites";
const DOMAINS = "domains"; const DOMAINS = "domains";
const REMOTE_EXIT_NODES = "remoteExitNodes"; const REMOTE_EXIT_NODES = "remoteExitNodes";
@@ -297,8 +304,70 @@ export default function BillingPage() {
await api.post(`/org/${org.org.orgId}/billing/change-tier`, { await api.post(`/org/${org.org.orgId}/billing/change-tier`, {
tier tier
}); });
// Refresh subscription data
window.location.reload(); // Poll the API to check if the tier change has been reflected
const pollForTierChange = async (targetTier: Tier) => {
const maxAttempts = 30; // 30 seconds with 1 second interval
let attempts = 0;
const poll = async (): Promise<boolean> => {
try {
const res = await api.get<
AxiosResponse<GetOrgSubscriptionResponse>
>(`/org/${org.org.orgId}/billing/subscriptions`);
const { subscriptions } = res.data.data;
// Find tier subscription
const tierSub = subscriptions.find(
({ subscription }) =>
subscription?.type === "tier1" ||
subscription?.type === "tier2" ||
subscription?.type === "tier3"
);
// Check if the tier has changed to the target tier
if (tierSub?.subscription?.type === targetTier) {
return true;
}
return false;
} catch (error) {
console.error("Error polling subscription:", error);
return false;
}
};
while (attempts < maxAttempts) {
const success = await poll();
if (success) {
// Tier change reflected, refresh the page
window.location.reload();
return;
}
attempts++;
if (attempts < maxAttempts) {
// Wait 1 second before next poll
await new Promise((resolve) =>
setTimeout(resolve, 1000)
);
}
}
// If we've exhausted all attempts, show an error
toast({
title: "Tier change processing",
description:
"Your tier change is taking longer than expected. Please refresh the page in a moment to see the changes.",
variant: "destructive"
});
setIsLoading(false);
};
// Start polling for the tier change
pollForTierChange(tier);
} catch (error) { } catch (error) {
toast({ toast({
title: "Failed to change tier", title: "Failed to change tier",
@@ -323,8 +392,8 @@ export default function BillingPage() {
} }
} }
setShowConfirmDialog(false); // setShowConfirmDialog(false);
setPendingTier(null); // setPendingTier(null);
}; };
const showTierConfirmation = ( const showTierConfirmation = (
@@ -338,7 +407,7 @@ export default function BillingPage() {
}; };
const handleContactUs = () => { const handleContactUs = () => {
window.open("mailto:sales@pangolin.net", "_blank"); window.open("https://pangolin.net/talk-to-us", "_blank");
}; };
// Get current plan ID from tier // Get current plan ID from tier
@@ -393,7 +462,7 @@ export default function BillingPage() {
plan.tierType, plan.tierType,
"downgrade", "downgrade",
plan.name, plan.name,
plan.price + (plan.priceDetail || "") plan.price + (" " + plan.priceDetail || "")
); );
} else { } else {
handleModifySubscription(); handleModifySubscription();
@@ -412,7 +481,7 @@ export default function BillingPage() {
plan.tierType, plan.tierType,
"upgrade", "upgrade",
plan.name, plan.name,
plan.price + (plan.priceDetail || "") plan.price + (" " + plan.priceDetail || "")
); );
} else { } else {
handleModifySubscription(); handleModifySubscription();
@@ -438,17 +507,15 @@ export default function BillingPage() {
// Calculate current usage cost for display // Calculate current usage cost for display
const getUserCount = () => getUsageValue(USERS); const getUserCount = () => getUsageValue(USERS);
const getPricePerUser = () => { const getPricePerUser = () => {
console.log(
"Calculating price per user, tierSubscription:",
tierSubscription
);
if (!tierSubscription?.items) return 0; if (!tierSubscription?.items) return 0;
// Find the subscription item for USERS feature // Find the subscription item for USERS feature
const usersItem = tierSubscription.items.find( const usersItem = tierSubscription.items.find(
(item) => item.planId === USERS (item) => item.featureId === USERS
); );
console.log("Users subscription item:", usersItem);
// unitAmount is in cents, convert to dollars // unitAmount is in cents, convert to dollars
if (usersItem?.unitAmount) { if (usersItem?.unitAmount) {
return usersItem.unitAmount / 100; return usersItem.unitAmount / 100;
@@ -529,6 +596,7 @@ export default function BillingPage() {
disabled={ disabled={
isLoading || planAction.disabled isLoading || planAction.disabled
} }
loading={isLoading && isCurrentPlan}
> >
{planAction.label} {planAction.label}
</Button> </Button>
@@ -736,9 +804,9 @@ export default function BillingPage() {
<span> <span>
{ {
tierLimits[pendingTier.tier] tierLimits[pendingTier.tier]
.sites .users
}{" "} }{" "}
{t("billingSites") || "Sites"} {t("billingUsers") || "Users"}
</span> </span>
</div> </div>
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
@@ -746,9 +814,9 @@ export default function BillingPage() {
<span> <span>
{ {
tierLimits[pendingTier.tier] tierLimits[pendingTier.tier]
.users .sites
}{" "} }{" "}
{t("billingUsers") || "Users"} {t("billingSites") || "Sites"}
</span> </span>
</div> </div>
<div className="flex items-center gap-2"> <div className="flex items-center gap-2">
@@ -787,14 +855,13 @@ export default function BillingPage() {
<Button <Button
onClick={confirmTierChange} onClick={confirmTierChange}
disabled={isLoading} disabled={isLoading}
loading={isLoading}
> >
{isLoading {pendingTier?.action === "upgrade"
? t("billingProcessing") || "Processing..." ? t("billingConfirmUpgradeButton") ||
: pendingTier?.action === "upgrade" "Confirm Upgrade"
? t("billingConfirmUpgradeButton") || : t("billingConfirmDowngradeButton") ||
"Confirm Upgrade" "Confirm Downgrade"}
: t("billingConfirmDowngradeButton") ||
"Confirm Downgrade"}
</Button> </Button>
</CredenzaFooter> </CredenzaFooter>
</CredenzaContent> </CredenzaContent>