diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index 0512af22..2ebb145b 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -82,7 +82,8 @@ export const subscriptions = pgTable("subscriptions", { canceledAt: bigint("canceledAt", { mode: "number" }), createdAt: bigint("createdAt", { mode: "number" }).notNull(), updatedAt: bigint("updatedAt", { mode: "number" }), - billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" }) + billingCycleAnchor: bigint("billingCycleAnchor", { mode: "number" }), + type: varchar("type", { length: 50 }) // home_lab, starter, scale, or license }); export const subscriptionItems = pgTable("subscriptionItems", { diff --git a/server/db/sqlite/schema/privateSchema.ts b/server/db/sqlite/schema/privateSchema.ts index 2661ccdd..27979460 100644 --- a/server/db/sqlite/schema/privateSchema.ts +++ b/server/db/sqlite/schema/privateSchema.ts @@ -70,7 +70,8 @@ export const subscriptions = sqliteTable("subscriptions", { canceledAt: integer("canceledAt"), createdAt: integer("createdAt").notNull(), updatedAt: integer("updatedAt"), - billingCycleAnchor: integer("billingCycleAnchor") + billingCycleAnchor: integer("billingCycleAnchor"), + type: text("type") // home_lab, starter, scale, or license }); export const subscriptionItems = sqliteTable("subscriptionItems", { diff --git a/server/lib/billing/features.ts b/server/lib/billing/features.ts index 1215a829..a3ab0cc8 100644 --- a/server/lib/billing/features.ts +++ b/server/lib/billing/features.ts @@ -5,15 +5,16 @@ export enum FeatureId { SITES = "sites", EGRESS_DATA_MB = "egressDataMb", DOMAINS = "domains", - REMOTE_EXIT_NODES = "remoteExitNodes" + REMOTE_EXIT_NODES = "remoteExitNodes", + HOME_LAB = "home_lab" } -export const FeatureMeterIds: Partial> = { - [FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW" +export const FeatureMeterIds: Partial> = { // right now we are not charging for any data + // [FeatureId.EGRESS_DATA_MB]: "mtr_61Srreh9eWrExDSCe41D3Ee2Ir7Wm5YW" }; export const FeatureMeterIdsSandbox: Partial> = { - [FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ" + // [FeatureId.EGRESS_DATA_MB]: "mtr_test_61Snh2a2m6qome5Kv41DCpkOb237B3dQ" }; export function getFeatureMeterId(featureId: FeatureId): string | undefined { @@ -37,12 +38,31 @@ export function getFeatureIdByMetricId( export type FeaturePriceSet = Partial>; +export const homeLabFeaturePriceSet: FeaturePriceSet = { + [FeatureId.HOME_LAB]: "price_1SxgpPDCpkOb237Bfo4rIsoT" +}; + +export const homeLabFeaturePriceSetSandbox: FeaturePriceSet = { + [FeatureId.HOME_LAB]: "price_1SxgpPDCpkOb237Bfo4rIsoT" +}; + +export function getHomeLabFeaturePriceSet(): FeaturePriceSet { + if ( + process.env.ENVIRONMENT == "prod" && + process.env.SANDBOX_MODE !== "true" + ) { + return homeLabFeaturePriceSet; + } else { + return homeLabFeaturePriceSetSandbox; + } +} + export const starterFeaturePriceSet: FeaturePriceSet = { - [FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea" + [FeatureId.USERS]: "price_1SxaEHDCpkOb237BD9lBkPiR" }; export const starterFeaturePriceSetSandbox: FeaturePriceSet = { - [FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF" + [FeatureId.USERS]: "price_1SxaEHDCpkOb237BD9lBkPiR" }; export function getStarterFeaturePriceSet(): FeaturePriceSet { @@ -57,11 +77,11 @@ export function getStarterFeaturePriceSet(): FeaturePriceSet { } export const scaleFeaturePriceSet: FeaturePriceSet = { - [FeatureId.USERS]: "price_1RrQeJD3Ee2Ir7WmgveP3xea" + [FeatureId.USERS]: "price_1SxaEODCpkOb237BiXdCBSfs" }; export const scaleFeaturePriceSetSandbox: FeaturePriceSet = { - [FeatureId.USERS]: "price_1ReNa4DCpkOb237Bc67G5muF" + [FeatureId.USERS]: "price_1SxaEODCpkOb237BiXdCBSfs" }; export function getScaleFeaturePriceSet(): FeaturePriceSet { diff --git a/server/lib/billing/limitSet.ts b/server/lib/billing/limitSet.ts index a7a21809..12aea306 100644 --- a/server/lib/billing/limitSet.ts +++ b/server/lib/billing/limitSet.ts @@ -26,48 +26,59 @@ export const freeLimitSet: LimitSet = { [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Free tier limit" } }; +export const homeLabLimitSet: LimitSet = { + [FeatureId.SITES]: { value: 3, description: "Home lab limit" }, // 1 site up for 32 days + [FeatureId.USERS]: { value: 3, description: "Home lab limit" }, + [FeatureId.EGRESS_DATA_MB]: { + value: 25000, + description: "Home lab limit" + }, // 25 GB + [FeatureId.DOMAINS]: { value: 3, description: "Home lab limit" }, + [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home lab limit" } +}; + export const starterLimitSet: LimitSet = { [FeatureId.SITES]: { value: 10, - description: "Contact us to increase soft limit." + description: "Starter limit" }, // 50 sites up for 31 days [FeatureId.USERS]: { value: 150, - description: "Contact us to increase soft limit." + description: "Starter limit" }, [FeatureId.EGRESS_DATA_MB]: { value: 12000000, - description: "Contact us to increase soft limit." + description: "Starter limit" }, // 12000 GB [FeatureId.DOMAINS]: { value: 250, - description: "Contact us to increase soft limit." + description: "Starter limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 5, - description: "Contact us to increase soft limit." + description: "Starter limit" } }; export const scaleLimitSet: LimitSet = { [FeatureId.SITES]: { value: 10, - description: "Contact us to increase soft limit." + description: "Scale limit" }, // 50 sites up for 31 days [FeatureId.USERS]: { value: 150, - description: "Contact us to increase soft limit." + description: "Scale limit" }, [FeatureId.EGRESS_DATA_MB]: { value: 12000000, - description: "Contact us to increase soft limit." + description: "Scale limit" }, // 12000 GB [FeatureId.DOMAINS]: { value: 250, - description: "Contact us to increase soft limit." + description: "Scale limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 5, - description: "Contact us to increase soft limit." + description: "Scale limit" } }; diff --git a/server/lib/billing/tiers.ts b/server/lib/billing/tiers.ts deleted file mode 100644 index ae49a48f..00000000 --- a/server/lib/billing/tiers.ts +++ /dev/null @@ -1,34 +0,0 @@ -export enum TierId { - STANDARD = "standard" -} - -export type TierPriceSet = { - [key in TierId]: string; -}; - -export const tierPriceSet: TierPriceSet = { - // Free tier matches the freeLimitSet - [TierId.STANDARD]: "price_1RrQ9cD3Ee2Ir7Wmqdy3KBa0" -}; - -export const tierPriceSetSandbox: TierPriceSet = { - // Free tier matches the freeLimitSet - // when matching tier the keys closer to 0 index are matched first so list the tiers in descending order of value - [TierId.STANDARD]: "price_1RrAYJDCpkOb237By2s1P32m" -}; - -export function getTierPriceSet( - environment?: string, - sandbox_mode?: boolean -): TierPriceSet { - if ( - (process.env.ENVIRONMENT == "prod" && - process.env.SANDBOX_MODE !== "true") || - (environment === "prod" && sandbox_mode !== true) - ) { - // THIS GETS LOADED CLIENT SIDE AND SERVER SIDE - return tierPriceSet; - } else { - return tierPriceSetSandbox; - } -} diff --git a/server/private/lib/billing/getOrgTierData.ts b/server/private/lib/billing/getOrgTierData.ts index adda2414..68e7ea2c 100644 --- a/server/private/lib/billing/getOrgTierData.ts +++ b/server/private/lib/billing/getOrgTierData.ts @@ -29,28 +29,5 @@ export async function getOrgTierData( const subscriptionsWithItems = await getOrgSubscriptionsData(orgId); - for (const { subscription, items } of subscriptionsWithItems) { - if (items && items.length > 0) { - const tierPriceSet = getTierPriceSet(); - // Iterate through tiers in order (earlier keys are higher tiers) - for (const [tierId, priceId] of Object.entries(tierPriceSet)) { - // Check if any subscription item matches this tier's price ID - const matchingItem = items.find((item) => item.priceId === priceId); - if (matchingItem) { - tier = tierId; - break; - } - } - } - - if (subscription && subscription.status === "active") { - active = true; - } - - // If we found a tier and active subscription, we can stop - if (tier && active) { - break; - } - } return { tier, active }; } diff --git a/server/private/routers/auth/index.ts b/server/private/routers/auth/index.ts index 535d5887..25adfa78 100644 --- a/server/private/routers/auth/index.ts +++ b/server/private/routers/auth/index.ts @@ -13,4 +13,3 @@ export * from "./transferSession"; export * from "./getSessionTransferToken"; -export * from "./quickStart"; diff --git a/server/private/routers/auth/quickStart.ts b/server/private/routers/auth/quickStart.ts deleted file mode 100644 index 612a3951..00000000 --- a/server/private/routers/auth/quickStart.ts +++ /dev/null @@ -1,585 +0,0 @@ -/* - * This file is part of a proprietary work. - * - * Copyright (c) 2025 Fossorial, Inc. - * All rights reserved. - * - * This file is licensed under the Fossorial Commercial License. - * You may not use this file except in compliance with the License. - * Unauthorized use, copying, modification, or distribution is strictly prohibited. - * - * This file is not licensed under the AGPLv3. - */ - -import { NextFunction, Request, Response } from "express"; -import { - account, - db, - domainNamespaces, - domains, - exitNodes, - newts, - newtSessions, - orgs, - passwordResetTokens, - Resource, - resourcePassword, - resourcePincode, - resources, - resourceWhitelist, - roleResources, - roles, - roleSites, - sites, - targetHealthCheck, - targets, - userResources, - userSites -} from "@server/db"; -import HttpCode from "@server/types/HttpCode"; -import { z } from "zod"; -import { users } from "@server/db"; -import { fromError } from "zod-validation-error"; -import createHttpError from "http-errors"; -import response from "@server/lib/response"; -import { SqliteError } from "better-sqlite3"; -import { eq, and, sql } from "drizzle-orm"; -import moment from "moment"; -import { generateId } from "@server/auth/sessions/app"; -import config from "@server/lib/config"; -import logger from "@server/logger"; -import { hashPassword } from "@server/auth/password"; -import { UserType } from "@server/types/UserTypes"; -import { createUserAccountOrg } from "@server/lib/createUserAccountOrg"; -import { sendEmail } from "@server/emails"; -import WelcomeQuickStart from "@server/emails/templates/WelcomeQuickStart"; -import { alphabet, generateRandomString } from "oslo/crypto"; -import { createDate, TimeSpan } from "oslo"; -import { getUniqueResourceName, getUniqueSiteName } from "@server/db/names"; -import { pickPort } from "@server/routers/target/helpers"; -import { addTargets } from "@server/routers/newt/targets"; -import { isTargetValid } from "@server/lib/validators"; -import { listExitNodes } from "#private/lib/exitNodes"; - -const bodySchema = z.object({ - email: z.email().toLowerCase(), - ip: z.string().refine(isTargetValid), - method: z.enum(["http", "https"]), - port: z.int().min(1).max(65535), - pincode: z - .string() - .regex(/^\d{6}$/) - .optional(), - password: z.string().min(4).max(100).optional(), - enableWhitelist: z.boolean().optional().default(true), - animalId: z.string() // This is actually the secret key for the backend -}); - -export type QuickStartBody = z.infer; - -export type QuickStartResponse = { - newtId: string; - newtSecret: string; - resourceUrl: string; - completeSignUpLink: string; -}; - -const DEMO_UBO_KEY = "b460293f-347c-4b30-837d-4e06a04d5a22"; - -export async function quickStart( - req: Request, - res: Response, - next: NextFunction -): Promise { - const parsedBody = bodySchema.safeParse(req.body); - - if (!parsedBody.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - fromError(parsedBody.error).toString() - ) - ); - } - - const { - email, - ip, - method, - port, - pincode, - password, - enableWhitelist, - animalId - } = parsedBody.data; - - try { - const tokenValidation = validateTokenOnApi(animalId); - - if (!tokenValidation.isValid) { - logger.warn( - `Quick start failed for ${email} token ${animalId}: ${tokenValidation.message}` - ); - return next( - createHttpError( - HttpCode.BAD_REQUEST, - "Invalid or expired token" - ) - ); - } - - if (animalId === DEMO_UBO_KEY) { - if (email !== "mehrdad@getubo.com") { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - "Invalid email for demo Ubo key" - ) - ); - } - - const [existing] = await db - .select() - .from(users) - .where( - and( - eq(users.email, email), - eq(users.type, UserType.Internal) - ) - ); - - if (existing) { - // delete the user if it already exists - await db.delete(users).where(eq(users.userId, existing.userId)); - const orgId = `org_${existing.userId}`; - await db.delete(orgs).where(eq(orgs.orgId, orgId)); - } - } - - const tempPassword = generateId(15); - const passwordHash = await hashPassword(tempPassword); - const userId = generateId(15); - - // TODO: see if that user already exists? - - // Create the sandbox user - const existing = await db - .select() - .from(users) - .where( - and(eq(users.email, email), eq(users.type, UserType.Internal)) - ); - - if (existing && existing.length > 0) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - "A user with that email address already exists" - ) - ); - } - - let newtId: string; - let secret: string; - let fullDomain: string; - let resource: Resource; - let completeSignUpLink: string; - - await db.transaction(async (trx) => { - await trx.insert(users).values({ - userId: userId, - type: UserType.Internal, - username: email, - email: email, - passwordHash, - dateCreated: moment().toISOString() - }); - - // create user"s account - await trx.insert(account).values({ - userId - }); - }); - - const { success, error, org } = await createUserAccountOrg( - userId, - email - ); - if (!success) { - if (error) { - throw new Error(error); - } - throw new Error("Failed to create user account and organization"); - } - if (!org) { - throw new Error("Failed to create user account and organization"); - } - - const orgId = org.orgId; - - await db.transaction(async (trx) => { - const token = generateRandomString( - 8, - alphabet("0-9", "A-Z", "a-z") - ); - - await trx - .delete(passwordResetTokens) - .where(eq(passwordResetTokens.userId, userId)); - - const tokenHash = await hashPassword(token); - - await trx.insert(passwordResetTokens).values({ - userId: userId, - email: email, - tokenHash, - expiresAt: createDate(new TimeSpan(7, "d")).getTime() - }); - - // // Create the sandbox newt - // const newClientAddress = await getNextAvailableClientSubnet(orgId); - // if (!newClientAddress) { - // throw new Error("No available subnet found"); - // } - - // const clientAddress = newClientAddress.split("/")[0]; - - newtId = generateId(15); - secret = generateId(48); - - // Create the sandbox site - const siteNiceId = await getUniqueSiteName(orgId); - const siteName = `First Site`; - - // pick a random exit node - const exitNodesList = await listExitNodes(orgId); - - // select a random exit node - const randomExitNode = - exitNodesList[Math.floor(Math.random() * exitNodesList.length)]; - - if (!randomExitNode) { - throw new Error("No exit nodes available"); - } - - const [newSite] = await trx - .insert(sites) - .values({ - orgId, - exitNodeId: randomExitNode.exitNodeId, - name: siteName, - niceId: siteNiceId, - // address: clientAddress, - type: "newt", - dockerSocketEnabled: true - }) - .returning(); - - const siteId = newSite.siteId; - - const adminRole = await trx - .select() - .from(roles) - .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) - .limit(1); - - if (adminRole.length === 0) { - throw new Error("Admin role not found"); - } - - await trx.insert(roleSites).values({ - roleId: adminRole[0].roleId, - siteId: newSite.siteId - }); - - if (req.user && req.userOrgRoleId != adminRole[0].roleId) { - // make sure the user can access the site - await trx.insert(userSites).values({ - userId: req.user?.userId!, - siteId: newSite.siteId - }); - } - - // add the peer to the exit node - const secretHash = await hashPassword(secret!); - - await trx.insert(newts).values({ - newtId: newtId!, - secretHash, - siteId: newSite.siteId, - dateCreated: moment().toISOString() - }); - - const [randomNamespace] = await trx - .select() - .from(domainNamespaces) - .orderBy(sql`RANDOM()`) - .limit(1); - - if (!randomNamespace) { - throw new Error("No domain namespace available"); - } - - const [randomNamespaceDomain] = await trx - .select() - .from(domains) - .where(eq(domains.domainId, randomNamespace.domainId)) - .limit(1); - - if (!randomNamespaceDomain) { - throw new Error("No domain found for the namespace"); - } - - const resourceNiceId = await getUniqueResourceName(orgId); - - // Create sandbox resource - const subdomain = `${resourceNiceId}-${generateId(5)}`; - fullDomain = `${subdomain}.${randomNamespaceDomain.baseDomain}`; - - const resourceName = `First Resource`; - - const newResource = await trx - .insert(resources) - .values({ - niceId: resourceNiceId, - fullDomain, - domainId: randomNamespaceDomain.domainId, - orgId, - name: resourceName, - subdomain, - http: true, - protocol: "tcp", - ssl: true, - sso: false, - emailWhitelistEnabled: enableWhitelist - }) - .returning(); - - await trx.insert(roleResources).values({ - roleId: adminRole[0].roleId, - resourceId: newResource[0].resourceId - }); - - if (req.user && req.userOrgRoleId != adminRole[0].roleId) { - // make sure the user can access the resource - await trx.insert(userResources).values({ - userId: req.user?.userId!, - resourceId: newResource[0].resourceId - }); - } - - resource = newResource[0]; - - // Create the sandbox target - const { internalPort, targetIps } = await pickPort(siteId!, trx); - - if (!internalPort) { - throw new Error("No available internal port"); - } - - const newTarget = await trx - .insert(targets) - .values({ - resourceId: resource.resourceId, - siteId: siteId!, - internalPort, - ip, - method, - port, - enabled: true - }) - .returning(); - - const newHealthcheck = await trx - .insert(targetHealthCheck) - .values({ - targetId: newTarget[0].targetId, - hcEnabled: false - }) - .returning(); - - // add the new target to the targetIps array - targetIps.push(`${ip}/32`); - - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, siteId!)) - .limit(1); - - await addTargets( - newt.newtId, - newTarget, - newHealthcheck, - resource.protocol - ); - - // Set resource pincode if provided - if (pincode) { - await trx - .delete(resourcePincode) - .where( - eq(resourcePincode.resourceId, resource!.resourceId) - ); - - const pincodeHash = await hashPassword(pincode); - - await trx.insert(resourcePincode).values({ - resourceId: resource!.resourceId, - pincodeHash, - digitLength: 6 - }); - } - - // Set resource password if provided - if (password) { - await trx - .delete(resourcePassword) - .where( - eq(resourcePassword.resourceId, resource!.resourceId) - ); - - const passwordHash = await hashPassword(password); - - await trx.insert(resourcePassword).values({ - resourceId: resource!.resourceId, - passwordHash - }); - } - - // Set resource OTP if whitelist is enabled - if (enableWhitelist) { - await trx.insert(resourceWhitelist).values({ - email, - resourceId: resource!.resourceId - }); - } - - completeSignUpLink = `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}&token=${token}`; - - // Store token for email outside transaction - await sendEmail( - WelcomeQuickStart({ - username: email, - link: completeSignUpLink, - fallbackLink: `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}`, - resourceMethod: method, - resourceHostname: ip, - resourcePort: port, - resourceUrl: `https://${fullDomain}`, - cliCommand: `newt --id ${newtId} --secret ${secret}` - }), - { - to: email, - from: config.getNoReplyEmail(), - subject: `Access your Pangolin dashboard and resources` - } - ); - }); - - return response(res, { - data: { - newtId: newtId!, - newtSecret: secret!, - resourceUrl: `https://${fullDomain!}`, - completeSignUpLink: completeSignUpLink! - }, - success: true, - error: false, - message: "Quick start completed successfully", - status: HttpCode.OK - }); - } catch (e) { - if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") { - if (config.getRawConfig().app.log_failed_attempts) { - logger.info( - `Account already exists with that email. Email: ${email}. IP: ${req.ip}.` - ); - } - return next( - createHttpError( - HttpCode.BAD_REQUEST, - "A user with that email address already exists" - ) - ); - } else { - logger.error(e); - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "Failed to do quick start" - ) - ); - } - } -} - -const BACKEND_SECRET_KEY = "4f9b6000-5d1a-11f0-9de7-ff2cc032f501"; - -/** - * Validates a token received from the frontend. - * @param {string} token The validation token from the request. - * @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid. - */ -const validateTokenOnApi = ( - token: string -): { isValid: boolean; message: string } => { - if (token === DEMO_UBO_KEY) { - // Special case for demo UBO key - return { isValid: true, message: "Demo UBO key is valid." }; - } - - if (!token) { - return { isValid: false, message: "Error: No token provided." }; - } - - try { - // 1. Decode the base64 string - const decodedB64 = atob(token); - - // 2. Reverse the character code manipulation - const deobfuscated = decodedB64 - .split("") - .map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift - .join(""); - - // 3. Split the data to get the original secret and timestamp - const parts = deobfuscated.split("|"); - if (parts.length !== 2) { - throw new Error("Invalid token format."); - } - const receivedKey = parts[0]; - const tokenTimestamp = parseInt(parts[1], 10); - - // 4. Check if the secret key matches - if (receivedKey !== BACKEND_SECRET_KEY) { - return { isValid: false, message: "Invalid token: Key mismatch." }; - } - - // 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks - const now = Date.now(); - const timeDifference = now - tokenTimestamp; - - if (timeDifference > 30000) { - // 30 seconds - return { isValid: false, message: "Invalid token: Expired." }; - } - - if (timeDifference < 0) { - // Timestamp is in the future - return { - isValid: false, - message: "Invalid token: Timestamp is in the future." - }; - } - - // If all checks pass, the token is valid - return { isValid: true, message: "Token is valid!" }; - } catch (error) { - // This will catch errors from atob (if not valid base64) or other issues. - return { - isValid: false, - message: `Error: ${(error as Error).message}` - }; - } -}; diff --git a/server/private/routers/billing/changeTier.ts b/server/private/routers/billing/changeTier.ts new file mode 100644 index 00000000..68a90c92 --- /dev/null +++ b/server/private/routers/billing/changeTier.ts @@ -0,0 +1,263 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { customers, db, subscriptions, subscriptionItems } from "@server/db"; +import { eq, and, or } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import stripe from "#private/lib/stripe"; +import { + getHomeLabFeaturePriceSet, + getScaleFeaturePriceSet, + getStarterFeaturePriceSet, + FeatureId, + type FeaturePriceSet +} from "@server/lib/billing"; + +const changeTierSchema = z.strictObject({ + orgId: z.string() +}); + +const changeTierBodySchema = z.strictObject({ + tier: z.enum(["home_lab", "starter", "scale"]) +}); + +export async function changeTier( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = changeTierSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { orgId } = parsedParams.data; + + const parsedBody = changeTierBodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { tier } = parsedBody.data; + + // Get the customer for this org + const [customer] = await db + .select() + .from(customers) + .where(eq(customers.orgId, orgId)) + .limit(1); + + if (!customer) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No customer found for this organization" + ) + ); + } + + // Get the active subscription for this customer + const [subscription] = await db + .select() + .from(subscriptions) + .where( + and( + eq(subscriptions.customerId, customer.customerId), + eq(subscriptions.status, "active"), + or( + eq(subscriptions.type, "home_lab"), + eq(subscriptions.type, "starter"), + eq(subscriptions.type, "scale") + ) + ) + ) + .limit(1); + + if (!subscription) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No active subscription found for this organization" + ) + ); + } + + // Get the target tier's price set + let targetPriceSet: FeaturePriceSet; + if (tier === "home_lab") { + targetPriceSet = getHomeLabFeaturePriceSet(); + } else if (tier === "starter") { + targetPriceSet = getStarterFeaturePriceSet(); + } else if (tier === "scale") { + targetPriceSet = getScaleFeaturePriceSet(); + } else { + return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid tier")); + } + + // Get current subscription items from our database + const currentItems = await db + .select() + .from(subscriptionItems) + .where( + eq( + subscriptionItems.subscriptionId, + subscription.subscriptionId + ) + ); + + if (currentItems.length === 0) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "No subscription items found" + ) + ); + } + + // Retrieve the full subscription from Stripe to get item IDs + const stripeSubscription = await stripe!.subscriptions.retrieve( + subscription.subscriptionId + ); + + // Determine if we're switching between different products + // home_lab uses HOME_LAB product, starter/scale use USERS product + const currentTier = subscription.type; + const switchingProducts = + (currentTier === "home_lab" && (tier === "starter" || tier === "scale")) || + ((currentTier === "starter" || currentTier === "scale") && tier === "home_lab"); + + let updatedSubscription; + + if (switchingProducts) { + // When switching between different products, we need to: + // 1. Delete old subscription items + // 2. Add new subscription items + logger.info( + `Switching products from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}` + ); + + // Build array to delete all existing items and add new ones + const itemsToUpdate: any[] = []; + + // Mark all existing items for deletion + for (const stripeItem of stripeSubscription.items.data) { + itemsToUpdate.push({ + id: stripeItem.id, + deleted: true + }); + } + + // Add new items for the target tier + for (const [featureId, priceId] of Object.entries(targetPriceSet)) { + itemsToUpdate.push({ + price: priceId + }); + } + + updatedSubscription = await stripe!.subscriptions.update( + subscription.subscriptionId, + { + items: itemsToUpdate, + proration_behavior: "create_prorations" + } + ); + } else { + // Same product, different price tier (starter <-> scale) + // We can simply update the price + logger.info( + `Updating price from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}` + ); + + const itemsToUpdate = stripeSubscription.items.data.map( + (stripeItem) => { + // Find the corresponding item in our database + const dbItem = currentItems.find( + (item) => item.priceId === stripeItem.price.id + ); + + if (!dbItem) { + // Keep the existing item unchanged if we can't find it + return { + id: stripeItem.id, + price: stripeItem.price.id + }; + } + + // Map to the corresponding feature in the new tier + const newPriceId = targetPriceSet[FeatureId.USERS]; + + if (newPriceId) { + return { + id: stripeItem.id, + price: newPriceId + }; + } + + // If no mapping found, keep existing + return { + id: stripeItem.id, + price: stripeItem.price.id + }; + } + ); + + updatedSubscription = await stripe!.subscriptions.update( + subscription.subscriptionId, + { + items: itemsToUpdate, + proration_behavior: "create_prorations" + } + ); + } + + logger.info( + `Successfully changed tier to ${tier} for org ${orgId}, subscription ${subscription.subscriptionId}` + ); + + return response<{ subscriptionId: string; newTier: string }>(res, { + data: { + subscriptionId: updatedSubscription.id, + newTier: tier + }, + success: true, + error: false, + message: "Tier change successful", + status: HttpCode.OK + }); + } catch (error) { + logger.error("Error changing tier:", error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "An error occurred while changing tier" + ) + ); + } +} diff --git a/server/private/routers/billing/createCheckoutSessionSAAS.ts b/server/private/routers/billing/createCheckoutSession.ts similarity index 72% rename from server/private/routers/billing/createCheckoutSessionSAAS.ts rename to server/private/routers/billing/createCheckoutSession.ts index 0f9b783e..7c79b8fb 100644 --- a/server/private/routers/billing/createCheckoutSessionSAAS.ts +++ b/server/private/routers/billing/createCheckoutSession.ts @@ -22,13 +22,16 @@ import logger from "@server/logger"; import config from "@server/lib/config"; import { fromError } from "zod-validation-error"; import stripe from "#private/lib/stripe"; -import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing"; -import { getTierPriceSet, TierId } from "@server/lib/billing/tiers"; +import { getHomeLabFeaturePriceSet, getLineItems, getScaleFeaturePriceSet, getStarterFeaturePriceSet } from "@server/lib/billing"; const createCheckoutSessionSchema = z.strictObject({ orgId: z.string() }); +const createCheckoutSessionBodySchema = z.strictObject({ + tier: z.enum(["home_lab", "starter", "scale"]), +}); + export async function createCheckoutSessionSAAS( req: Request, res: Response, @@ -47,6 +50,18 @@ export async function createCheckoutSessionSAAS( const { orgId } = parsedParams.data; + const parsedBody = createCheckoutSessionBodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { tier } = parsedBody.data; + // check if we already have a customer for this org const [customer] = await db .select() @@ -65,18 +80,23 @@ export async function createCheckoutSessionSAAS( ); } - const standardTierPrice = getTierPriceSet()[TierId.STANDARD]; + let lineItems; + if (tier === "home_lab") { + lineItems = getLineItems(getHomeLabFeaturePriceSet()); + } else if (tier === "starter") { + lineItems = getLineItems(getStarterFeaturePriceSet()); + } else if (tier === "scale") { + lineItems = getLineItems(getScaleFeaturePriceSet()); + } else { + return next( + createHttpError(HttpCode.BAD_REQUEST, "Invalid plan") + ); + } const session = await stripe!.checkout.sessions.create({ client_reference_id: orgId, // So we can look it up the org later on the webhook billing_address_collection: "required", - line_items: [ - { - price: standardTierPrice, // Use the standard tier - quantity: 1 - }, - ...getLineItems(getStandardFeaturePriceSet()) - ], // Start with the standard feature set that matches the free limits + line_items: lineItems, customer: customer.customerId, mode: "subscription", success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?success=true&session_id={CHECKOUT_SESSION_ID}`, diff --git a/server/private/routers/billing/hooks/getSubType.ts b/server/private/routers/billing/hooks/getSubType.ts index 8cd07713..3618747d 100644 --- a/server/private/routers/billing/hooks/getSubType.ts +++ b/server/private/routers/billing/hooks/getSubType.ts @@ -1,35 +1,61 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + import { getLicensePriceSet, } from "@server/lib/billing/licenses"; import { - getTierPriceSet, -} from "@server/lib/billing/tiers"; + getHomeLabFeaturePriceSet, + getStarterFeaturePriceSet, + getScaleFeaturePriceSet, +} from "@server/lib/billing/features"; import Stripe from "stripe"; -export function getSubType(fullSubscription: Stripe.Response): "saas" | "license" { +export type SubscriptionType = "home_lab" | "starter" | "scale" | "license"; + +export function getSubType(fullSubscription: Stripe.Response): SubscriptionType | null { // Determine subscription type by checking subscription items - let type: "saas" | "license" = "saas"; - if (Array.isArray(fullSubscription.items?.data)) { - for (const item of fullSubscription.items.data) { - const priceId = item.price.id; + if (!Array.isArray(fullSubscription.items?.data) || fullSubscription.items.data.length === 0) { + return null; + } - // Check if price ID matches any license price - const licensePrices = Object.values(getLicensePriceSet()); + for (const item of fullSubscription.items.data) { + const priceId = item.price.id; - if (licensePrices.includes(priceId)) { - type = "license"; - break; - } + // Check if price ID matches any license price + const licensePrices = Object.values(getLicensePriceSet()); + if (licensePrices.includes(priceId)) { + return "license"; + } - // Check if price ID matches any tier price (saas) - const tierPrices = Object.values(getTierPriceSet()); + // Check if price ID matches home lab tier + const homeLabPrices = Object.values(getHomeLabFeaturePriceSet()); + if (homeLabPrices.includes(priceId)) { + return "home_lab"; + } - if (tierPrices.includes(priceId)) { - type = "saas"; - break; - } + // Check if price ID matches starter tier + const starterPrices = Object.values(getStarterFeaturePriceSet()); + if (starterPrices.includes(priceId)) { + return "starter"; + } + + // Check if price ID matches scale tier + const scalePrices = Object.values(getScaleFeaturePriceSet()); + if (scalePrices.includes(priceId)) { + return "scale"; } } - return type; -} + return null; +} \ No newline at end of file diff --git a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts index a51f825f..6238e65c 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionCreated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionCreated.ts @@ -59,6 +59,8 @@ export async function handleSubscriptionCreated( return; } + const type = getSubType(fullSubscription); + const newSubscription = { subscriptionId: subscription.id, customerId: subscription.customer as string, @@ -66,7 +68,8 @@ export async function handleSubscriptionCreated( canceledAt: subscription.canceled_at ? subscription.canceled_at : null, - createdAt: subscription.created + createdAt: subscription.created, + type: type }; await db.insert(subscriptions).values(newSubscription); @@ -129,10 +132,9 @@ export async function handleSubscriptionCreated( return; } - const type = getSubType(fullSubscription); - if (type === "saas") { + if (type === "home_lab" || type === "starter" || type === "scale") { logger.debug( - `Handling SAAS subscription lifecycle for org ${customer.orgId}` + `Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}` ); // we only need to handle the limit lifecycle for saas subscriptions not for the licenses await handleSubscriptionLifesycle( diff --git a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts index 21943354..6b21fb26 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts @@ -64,6 +64,8 @@ export async function handleSubscriptionUpdated( .where(eq(customers.customerId, subscription.customer as string)) .limit(1); + const type = getSubType(fullSubscription); + await db .update(subscriptions) .set({ @@ -72,7 +74,8 @@ export async function handleSubscriptionUpdated( ? subscription.canceled_at : null, updatedAt: Math.floor(Date.now() / 1000), - billingCycleAnchor: subscription.billing_cycle_anchor + billingCycleAnchor: subscription.billing_cycle_anchor, + type: type }) .where(eq(subscriptions.subscriptionId, subscription.id)); @@ -234,17 +237,16 @@ export async function handleSubscriptionUpdated( } // --- end usage update --- - const type = getSubType(fullSubscription); - if (type === "saas") { + if (type === "home_lab" || type === "starter" || type === "scale") { logger.debug( - `Handling SAAS subscription lifecycle for org ${customer.orgId}` + `Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}` ); // we only need to handle the limit lifecycle for saas subscriptions not for the licenses await handleSubscriptionLifesycle( customer.orgId, subscription.status ); - } else { + } else if (type === "license") { if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") { try { // WARNING: diff --git a/server/private/routers/billing/index.ts b/server/private/routers/billing/index.ts index e7770ec2..b7bf02d4 100644 --- a/server/private/routers/billing/index.ts +++ b/server/private/routers/billing/index.ts @@ -11,7 +11,7 @@ * This file is not licensed under the AGPLv3. */ -export * from "./createCheckoutSessionSAAS"; +export * from "./createCheckoutSession"; export * from "./createPortalSession"; export * from "./getOrgSubscriptions"; export * from "./getOrgUsage"; diff --git a/server/private/routers/external.ts b/server/private/routers/external.ts index ddc2afe0..962cede9 100644 --- a/server/private/routers/external.ts +++ b/server/private/routers/external.ts @@ -141,25 +141,8 @@ authenticated.post( ); if (build === "saas") { - unauthenticated.post( - "/quick-start", - rateLimit({ - windowMs: 15 * 60 * 1000, - max: 100, - keyGenerator: (req) => req.path, - handler: (req, res, next) => { - const message = `We're too busy right now. Please try again later.`; - return next( - createHttpError(HttpCode.TOO_MANY_REQUESTS, message) - ); - }, - store: createStore() - }), - auth.quickStart - ); - authenticated.post( - "/org/:orgId/billing/create-checkout-session-saas", + "/org/:orgId/billing/create-checkout-session", verifyOrgAccess, verifyUserHasAction(ActionsEnum.billing), logActionAudit(ActionsEnum.billing), diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index 7b10dc4c..75611f35 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -153,7 +153,7 @@ export default function GeneralPage() { setIsLoading(true); try { const response = await api.post>( - `/org/${org.org.orgId}/billing/create-checkout-session-saas`, + `/org/${org.org.orgId}/billing/create-checkout-session`, {} ); console.log("Checkout session response:", response.data);