diff --git a/server/lib/billing/features.ts b/server/lib/billing/features.ts index 82ba0676..9eade18f 100644 --- a/server/lib/billing/features.ts +++ b/server/lib/billing/features.ts @@ -1,6 +1,3 @@ -import Stripe from "stripe"; -import { usageService } from "./usageService"; - export enum FeatureId { USERS = "users", SITES = "sites", @@ -135,25 +132,3 @@ export function getFeatureIdByPriceId(priceId: string): FeatureId | undefined { return undefined; } - -export async function getLineItems( - featurePriceSet: FeaturePriceSet, - orgId: string, -): Promise { - const users = await usageService.getUsage(orgId, FeatureId.USERS); - - return Object.entries(featurePriceSet).map(([featureId, priceId]) => { - let quantity: number | undefined; - - if (featureId === FeatureId.USERS) { - quantity = users?.instantaneousValue || 1; - } else if (featureId === FeatureId.TIER1) { - quantity = 1; - } - - return { - price: priceId, - quantity: quantity - }; - }); -} diff --git a/server/lib/billing/getLineItems.ts b/server/lib/billing/getLineItems.ts new file mode 100644 index 00000000..d386e5e9 --- /dev/null +++ b/server/lib/billing/getLineItems.ts @@ -0,0 +1,25 @@ +import Stripe from "stripe"; +import { FeatureId, FeaturePriceSet } from "./features"; +import { usageService } from "./usageService"; + +export async function getLineItems( + featurePriceSet: FeaturePriceSet, + orgId: string, +): Promise { + const users = await usageService.getUsage(orgId, FeatureId.USERS); + + return Object.entries(featurePriceSet).map(([featureId, priceId]) => { + let quantity: number | undefined; + + if (featureId === FeatureId.USERS) { + quantity = users?.instantaneousValue || 1; + } else if (featureId === FeatureId.TIER1) { + quantity = 1; + } + + return { + price: priceId, + quantity: quantity + }; + }); +} diff --git a/server/middlewares/index.ts b/server/middlewares/index.ts index 305abaa8..6437c90e 100644 --- a/server/middlewares/index.ts +++ b/server/middlewares/index.ts @@ -29,3 +29,4 @@ export * from "./verifyUserIsOrgOwner"; export * from "./verifySiteResourceAccess"; export * from "./logActionAudit"; export * from "./verifyOlmAccess"; +export * from "./verifyLimits"; diff --git a/server/middlewares/integration/verifyApiKeyOrgAccess.ts b/server/middlewares/integration/verifyApiKeyOrgAccess.ts index c705dc0f..97880003 100644 --- a/server/middlewares/integration/verifyApiKeyOrgAccess.ts +++ b/server/middlewares/integration/verifyApiKeyOrgAccess.ts @@ -4,7 +4,6 @@ import { apiKeyOrg } from "@server/db"; import { and, eq } from "drizzle-orm"; import createHttpError from "http-errors"; import HttpCode from "@server/types/HttpCode"; -import logger from "@server/logger"; export async function verifyApiKeyOrgAccess( req: Request, diff --git a/server/middlewares/verifyLimits.ts b/server/middlewares/verifyLimits.ts new file mode 100644 index 00000000..99330c33 --- /dev/null +++ b/server/middlewares/verifyLimits.ts @@ -0,0 +1,47 @@ +import { Request, Response, NextFunction } from "express"; +import { db, orgs } from "@server/db"; +import { userOrgs } from "@server/db"; +import { and, eq } from "drizzle-orm"; +import createHttpError from "http-errors"; +import HttpCode from "@server/types/HttpCode"; +import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; +import { usageService } from "@server/lib/billing/usageService"; +import { build } from "@server/build"; + +export async function verifyLimits( + req: Request, + res: Response, + next: NextFunction +) { + if (build != "saas") { + return next(); + } + + const orgId = req.userOrgId || req.params.orgId; + + if (!orgId) { + return next(); // its fine if we silently fail here because this is not critical to operation or security and its better user experience if we dont fail + } + + try { + const reject = await usageService.checkLimitSet(orgId); + + if (reject) { + return next( + createHttpError( + HttpCode.PAYMENT_REQUIRED, + "Organization has exceeded its usage limits. Please upgrade your plan or contact support." + ) + ); + } + + return next(); + } catch (e) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Error checking limits" + ) + ); + } +} diff --git a/server/private/routers/billing/changeTier.ts b/server/private/routers/billing/changeTier.ts index 5d67b7e8..ee60c0ec 100644 --- a/server/private/routers/billing/changeTier.ts +++ b/server/private/routers/billing/changeTier.ts @@ -25,10 +25,10 @@ import { getHomeLabFeaturePriceSet, getScaleFeaturePriceSet, getStarterFeaturePriceSet, - getLineItems, FeatureId, type FeaturePriceSet } from "@server/lib/billing"; +import { getLineItems } from "@server/lib/billing/getLineItems"; const changeTierSchema = z.strictObject({ orgId: z.string() @@ -151,8 +151,10 @@ export async function changeTier( // tier1 uses TIER1 product, tier2/tier3 use USERS product const currentTier = subscription.type; const switchingProducts = - (currentTier === "tier1" && (tier === "tier2" || tier === "tier3")) || - ((currentTier === "tier2" || currentTier === "tier3") && tier === "tier1"); + (currentTier === "tier1" && + (tier === "tier2" || tier === "tier3")) || + ((currentTier === "tier2" || currentTier === "tier3") && + tier === "tier1"); let updatedSubscription; diff --git a/server/private/routers/billing/createCheckoutSession.ts b/server/private/routers/billing/createCheckoutSession.ts index 67eaa37e..dc722519 100644 --- a/server/private/routers/billing/createCheckoutSession.ts +++ b/server/private/routers/billing/createCheckoutSession.ts @@ -22,8 +22,12 @@ import logger from "@server/logger"; import config from "@server/lib/config"; import { fromError } from "zod-validation-error"; import stripe from "#private/lib/stripe"; -import { getHomeLabFeaturePriceSet, getLineItems, getScaleFeaturePriceSet, getStarterFeaturePriceSet } from "@server/lib/billing"; -import { usageService } from "@server/lib/billing/usageService"; +import { + getHomeLabFeaturePriceSet, + getScaleFeaturePriceSet, + getStarterFeaturePriceSet +} from "@server/lib/billing"; +import { getLineItems } from "@server/lib/billing/getLineItems"; import Stripe from "stripe"; const createCheckoutSessionSchema = z.strictObject({ @@ -31,7 +35,7 @@ const createCheckoutSessionSchema = z.strictObject({ }); const createCheckoutSessionBodySchema = z.strictObject({ - tier: z.enum(["tier1", "tier2", "tier3"]), + tier: z.enum(["tier1", "tier2", "tier3"]) }); export async function createCheckoutSession( @@ -90,12 +94,10 @@ export async function createCheckoutSession( } else if (tier === "tier3") { lineItems = await getLineItems(getScaleFeaturePriceSet(), orgId); } else { - return next( - createHttpError(HttpCode.BAD_REQUEST, "Invalid plan") - ); + return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan")); } - logger.debug(`Line items: ${JSON.stringify(lineItems)}`) + logger.debug(`Line items: ${JSON.stringify(lineItems)}`); const session = await stripe!.checkout.sessions.create({ client_reference_id: orgId, // So we can look it up the org later on the webhook diff --git a/server/routers/external.ts b/server/routers/external.ts index aff01bfa..48768598 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -41,7 +41,8 @@ import { verifyUserHasAction, verifyUserIsOrgOwner, verifySiteResourceAccess, - verifyOlmAccess + verifyOlmAccess, + verifyLimits } from "@server/middlewares"; import { ActionsEnum } from "@server/auth/actions"; import rateLimit, { ipKeyGenerator } from "express-rate-limit"; @@ -79,6 +80,7 @@ authenticated.get( authenticated.post( "/org/:orgId", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.updateOrg), logActionAudit(ActionsEnum.updateOrg), org.updateOrg @@ -161,6 +163,7 @@ authenticated.get( authenticated.put( "/org/:orgId/client", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createClient), logActionAudit(ActionsEnum.createClient), client.createClient @@ -178,6 +181,7 @@ authenticated.delete( authenticated.post( "/client/:clientId/archive", verifyClientAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.archiveClient), logActionAudit(ActionsEnum.archiveClient), client.archiveClient @@ -186,6 +190,7 @@ authenticated.post( authenticated.post( "/client/:clientId/unarchive", verifyClientAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.unarchiveClient), logActionAudit(ActionsEnum.unarchiveClient), client.unarchiveClient @@ -194,6 +199,7 @@ authenticated.post( authenticated.post( "/client/:clientId/block", verifyClientAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.blockClient), logActionAudit(ActionsEnum.blockClient), client.blockClient @@ -202,6 +208,7 @@ authenticated.post( authenticated.post( "/client/:clientId/unblock", verifyClientAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.unblockClient), logActionAudit(ActionsEnum.unblockClient), client.unblockClient @@ -210,6 +217,7 @@ authenticated.post( authenticated.post( "/client/:clientId", verifyClientAccess, // this will check if the user has access to the client + verifyLimits, verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client logActionAudit(ActionsEnum.updateClient), client.updateClient @@ -224,6 +232,7 @@ authenticated.post( authenticated.post( "/site/:siteId", verifySiteAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.updateSite), logActionAudit(ActionsEnum.updateSite), site.updateSite @@ -273,6 +282,7 @@ authenticated.get( authenticated.put( "/org/:orgId/site-resource", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createSiteResource), logActionAudit(ActionsEnum.createSiteResource), siteResource.createSiteResource @@ -303,6 +313,7 @@ authenticated.get( authenticated.post( "/site-resource/:siteResourceId", verifySiteResourceAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.updateSiteResource), logActionAudit(ActionsEnum.updateSiteResource), siteResource.updateSiteResource @@ -341,6 +352,7 @@ authenticated.post( "/site-resource/:siteResourceId/roles", verifySiteResourceAccess, verifyRoleAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.setResourceRoles), logActionAudit(ActionsEnum.setResourceRoles), siteResource.setSiteResourceRoles @@ -350,6 +362,7 @@ authenticated.post( "/site-resource/:siteResourceId/users", verifySiteResourceAccess, verifySetResourceUsers, + verifyLimits, verifyUserHasAction(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers), siteResource.setSiteResourceUsers @@ -359,6 +372,7 @@ authenticated.post( "/site-resource/:siteResourceId/clients", verifySiteResourceAccess, verifySetResourceClients, + verifyLimits, verifyUserHasAction(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers), siteResource.setSiteResourceClients @@ -368,6 +382,7 @@ authenticated.post( "/site-resource/:siteResourceId/clients/add", verifySiteResourceAccess, verifySetResourceClients, + verifyLimits, verifyUserHasAction(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers), siteResource.addClientToSiteResource @@ -377,6 +392,7 @@ authenticated.post( "/site-resource/:siteResourceId/clients/remove", verifySiteResourceAccess, verifySetResourceClients, + verifyLimits, verifyUserHasAction(ActionsEnum.setResourceUsers), logActionAudit(ActionsEnum.setResourceUsers), siteResource.removeClientFromSiteResource @@ -385,6 +401,7 @@ authenticated.post( authenticated.put( "/org/:orgId/resource", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createResource), logActionAudit(ActionsEnum.createResource), resource.createResource @@ -499,6 +516,7 @@ authenticated.get( authenticated.post( "/resource/:resourceId", verifyResourceAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.updateResource), logActionAudit(ActionsEnum.updateResource), resource.updateResource @@ -514,6 +532,7 @@ authenticated.delete( authenticated.put( "/resource/:resourceId/target", verifyResourceAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createTarget), logActionAudit(ActionsEnum.createTarget), target.createTarget @@ -528,6 +547,7 @@ authenticated.get( authenticated.put( "/resource/:resourceId/rule", verifyResourceAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createResourceRule), logActionAudit(ActionsEnum.createResourceRule), resource.createResourceRule @@ -577,6 +597,7 @@ authenticated.delete( authenticated.put( "/org/:orgId/role", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createRole), logActionAudit(ActionsEnum.createRole), role.createRole @@ -774,6 +795,7 @@ authenticated.delete( authenticated.put( "/org/:orgId/user", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createOrgUser), logActionAudit(ActionsEnum.createOrgUser), user.createOrgUser @@ -985,6 +1007,7 @@ authenticated.get( authenticated.put( `/org/:orgId/api-key`, verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createApiKey), logActionAudit(ActionsEnum.createApiKey), apiKeys.createOrgApiKey @@ -1010,6 +1033,7 @@ authenticated.get( authenticated.put( `/org/:orgId/domain`, verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.createOrgDomain), logActionAudit(ActionsEnum.createOrgDomain), domain.createOrgDomain @@ -1065,6 +1089,7 @@ authenticated.get( authenticated.put( "/org/:orgId/blueprint", verifyOrgAccess, + verifyLimits, verifyUserHasAction(ActionsEnum.applyBlueprint), blueprints.applyYAMLBlueprint ); diff --git a/server/routers/integration.ts b/server/routers/integration.ts index 9bb26398..59ed253f 100644 --- a/server/routers/integration.ts +++ b/server/routers/integration.ts @@ -146,7 +146,7 @@ authenticated.get( ); // Site Resource endpoints authenticated.put( - "/org/:orgId/private-resource", + "/org/:orgId/site-resource", verifyApiKeyOrgAccess, verifyApiKeyHasAction(ActionsEnum.createSiteResource), logActionAudit(ActionsEnum.createSiteResource), diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index 75747d23..36bd8911 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -42,7 +42,11 @@ import { import { useTranslations } from "use-intl"; import Link from "next/link"; import { Tier } from "@server/types/Tiers"; -import { tier1LimitSet, tier2LimitSet, tier3LimitSet } from "@server/lib/billing/limitSet"; +import { + tier1LimitSet, + tier2LimitSet, + tier3LimitSet +} from "@server/lib/billing/limitSet"; import { FeatureId } from "@server/lib/billing/features"; // Plan tier definitions matching the mockup @@ -93,7 +97,10 @@ const planOptions: PlanOption[] = [ ]; // Tier limits mapping derived from limit sets -const tierLimits: Record = { +const tierLimits: Record< + Tier, + { users: number; sites: number; domains: number; remoteNodes: number } +> = { tier1: { users: tier1LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0, @@ -151,8 +158,8 @@ export default function BillingPage() { const [currentTier, setCurrentTier] = useState(null); // Usage IDs - const SITES = "sites"; const USERS = "users"; + const SITES = "sites"; const DOMAINS = "domains"; const REMOTE_EXIT_NODES = "remoteExitNodes"; @@ -297,8 +304,70 @@ export default function BillingPage() { await api.post(`/org/${org.org.orgId}/billing/change-tier`, { tier }); - // Refresh subscription data - window.location.reload(); + + // Poll the API to check if the tier change has been reflected + const pollForTierChange = async (targetTier: Tier) => { + const maxAttempts = 30; // 30 seconds with 1 second interval + let attempts = 0; + + const poll = async (): Promise => { + try { + const res = await api.get< + AxiosResponse + >(`/org/${org.org.orgId}/billing/subscriptions`); + const { subscriptions } = res.data.data; + + // Find tier subscription + const tierSub = subscriptions.find( + ({ subscription }) => + subscription?.type === "tier1" || + subscription?.type === "tier2" || + subscription?.type === "tier3" + ); + + // Check if the tier has changed to the target tier + if (tierSub?.subscription?.type === targetTier) { + return true; + } + + return false; + } catch (error) { + console.error("Error polling subscription:", error); + return false; + } + }; + + while (attempts < maxAttempts) { + const success = await poll(); + + if (success) { + // Tier change reflected, refresh the page + window.location.reload(); + return; + } + + attempts++; + + if (attempts < maxAttempts) { + // Wait 1 second before next poll + await new Promise((resolve) => + setTimeout(resolve, 1000) + ); + } + } + + // If we've exhausted all attempts, show an error + toast({ + title: "Tier change processing", + description: + "Your tier change is taking longer than expected. Please refresh the page in a moment to see the changes.", + variant: "destructive" + }); + setIsLoading(false); + }; + + // Start polling for the tier change + pollForTierChange(tier); } catch (error) { toast({ title: "Failed to change tier", @@ -323,8 +392,8 @@ export default function BillingPage() { } } - setShowConfirmDialog(false); - setPendingTier(null); + // setShowConfirmDialog(false); + // setPendingTier(null); }; const showTierConfirmation = ( @@ -338,7 +407,7 @@ export default function BillingPage() { }; const handleContactUs = () => { - window.open("mailto:sales@pangolin.net", "_blank"); + window.open("https://pangolin.net/talk-to-us", "_blank"); }; // Get current plan ID from tier @@ -393,7 +462,7 @@ export default function BillingPage() { plan.tierType, "downgrade", plan.name, - plan.price + (plan.priceDetail || "") + plan.price + (" " + plan.priceDetail || "") ); } else { handleModifySubscription(); @@ -412,7 +481,7 @@ export default function BillingPage() { plan.tierType, "upgrade", plan.name, - plan.price + (plan.priceDetail || "") + plan.price + (" " + plan.priceDetail || "") ); } else { handleModifySubscription(); @@ -438,17 +507,15 @@ export default function BillingPage() { // Calculate current usage cost for display const getUserCount = () => getUsageValue(USERS); const getPricePerUser = () => { - console.log( - "Calculating price per user, tierSubscription:", - tierSubscription - ); if (!tierSubscription?.items) return 0; // Find the subscription item for USERS feature const usersItem = tierSubscription.items.find( - (item) => item.planId === USERS + (item) => item.featureId === USERS ); + console.log("Users subscription item:", usersItem); + // unitAmount is in cents, convert to dollars if (usersItem?.unitAmount) { return usersItem.unitAmount / 100; @@ -529,6 +596,7 @@ export default function BillingPage() { disabled={ isLoading || planAction.disabled } + loading={isLoading && isCurrentPlan} > {planAction.label} @@ -736,9 +804,9 @@ export default function BillingPage() { { tierLimits[pendingTier.tier] - .sites + .users }{" "} - {t("billingSites") || "Sites"} + {t("billingUsers") || "Users"}
@@ -746,9 +814,9 @@ export default function BillingPage() { { tierLimits[pendingTier.tier] - .users + .sites }{" "} - {t("billingUsers") || "Users"} + {t("billingSites") || "Sites"}
@@ -787,14 +855,13 @@ export default function BillingPage() {