mirror of
https://github.com/fosrl/pangolin.git
synced 2026-02-16 09:56:36 +00:00
Merge branch 'dev' into refactor/paginated-tables
This commit is contained in:
@@ -11,46 +11,59 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { getTierPriceSet } from "@server/lib/billing/tiers";
|
||||
import { getOrgSubscriptionsData } from "@server/private/routers/billing/getOrgSubscriptions";
|
||||
import { build } from "@server/build";
|
||||
import { db, customers, subscriptions } from "@server/db";
|
||||
import { Tier } from "@server/types/Tiers";
|
||||
import { eq, and, ne } from "drizzle-orm";
|
||||
|
||||
export async function getOrgTierData(
|
||||
orgId: string
|
||||
): Promise<{ tier: string | null; active: boolean }> {
|
||||
let tier = null;
|
||||
): Promise<{ tier: Tier | null; active: boolean }> {
|
||||
let tier: Tier | null = null;
|
||||
let active = false;
|
||||
|
||||
if (build !== "saas") {
|
||||
return { tier, active };
|
||||
}
|
||||
|
||||
// TODO: THIS IS INEFFICIENT!!! WE SHOULD IMPROVE HOW WE STORE TIERS WITH SUBSCRIPTIONS AND RETRIEVE THEM
|
||||
try {
|
||||
// Get customer for org
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
const subscriptionsWithItems = await getOrgSubscriptionsData(orgId);
|
||||
if (customer) {
|
||||
// Query for active subscriptions that are not license type
|
||||
const [subscription] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(
|
||||
and(
|
||||
eq(subscriptions.customerId, customer.customerId),
|
||||
eq(subscriptions.status, "active"),
|
||||
ne(subscriptions.type, "license")
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
for (const { subscription, items } of subscriptionsWithItems) {
|
||||
if (items && items.length > 0) {
|
||||
const tierPriceSet = getTierPriceSet();
|
||||
// Iterate through tiers in order (earlier keys are higher tiers)
|
||||
for (const [tierId, priceId] of Object.entries(tierPriceSet)) {
|
||||
// Check if any subscription item matches this tier's price ID
|
||||
const matchingItem = items.find((item) => item.priceId === priceId);
|
||||
if (matchingItem) {
|
||||
tier = tierId;
|
||||
break;
|
||||
if (subscription) {
|
||||
// Validate that subscription.type is one of the expected tier values
|
||||
if (
|
||||
subscription.type === "tier1" ||
|
||||
subscription.type === "tier2" ||
|
||||
subscription.type === "tier3"
|
||||
) {
|
||||
tier = subscription.type;
|
||||
active = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (subscription && subscription.status === "active") {
|
||||
active = true;
|
||||
}
|
||||
|
||||
// If we found a tier and active subscription, we can stop
|
||||
if (tier && active) {
|
||||
break;
|
||||
}
|
||||
} catch (error) {
|
||||
// If org not found or error occurs, return null tier and inactive
|
||||
// This is acceptable behavior as per the function signature
|
||||
}
|
||||
|
||||
return { tier, active };
|
||||
}
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
|
||||
import { build } from "@server/build";
|
||||
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import license from "#private/license/license";
|
||||
import { eq } from "drizzle-orm";
|
||||
import {
|
||||
@@ -80,6 +78,8 @@ export async function checkOrgAccessPolicy(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: check that the org is subscribed
|
||||
|
||||
// get the needed data
|
||||
|
||||
if (!props.org) {
|
||||
|
||||
@@ -65,6 +65,11 @@ export class PrivateConfig {
|
||||
this.rawPrivateConfig.branding?.logo?.dark_path || undefined;
|
||||
}
|
||||
|
||||
if (this.rawPrivateConfig.app.identity_provider_mode) {
|
||||
process.env.IDENTITY_PROVIDER_MODE =
|
||||
this.rawPrivateConfig.app.identity_provider_mode;
|
||||
}
|
||||
|
||||
process.env.BRANDING_LOGO_AUTH_WIDTH = this.rawPrivateConfig.branding
|
||||
?.logo?.auth_page?.width
|
||||
? this.rawPrivateConfig.branding?.logo?.auth_page?.width.toString()
|
||||
@@ -125,24 +130,10 @@ export class PrivateConfig {
|
||||
this.rawPrivateConfig.server.reo_client_id;
|
||||
}
|
||||
|
||||
if (this.rawPrivateConfig.stripe?.s3Bucket) {
|
||||
process.env.S3_BUCKET = this.rawPrivateConfig.stripe.s3Bucket;
|
||||
}
|
||||
if (this.rawPrivateConfig.stripe?.localFilePath) {
|
||||
process.env.LOCAL_FILE_PATH =
|
||||
this.rawPrivateConfig.stripe.localFilePath;
|
||||
}
|
||||
if (this.rawPrivateConfig.stripe?.s3Region) {
|
||||
process.env.S3_REGION = this.rawPrivateConfig.stripe.s3Region;
|
||||
}
|
||||
if (this.rawPrivateConfig.flags.use_pangolin_dns) {
|
||||
process.env.USE_PANGOLIN_DNS =
|
||||
this.rawPrivateConfig.flags.use_pangolin_dns.toString();
|
||||
}
|
||||
if (this.rawPrivateConfig.flags.use_org_only_idp) {
|
||||
process.env.USE_ORG_ONLY_IDP =
|
||||
this.rawPrivateConfig.flags.use_org_only_idp.toString();
|
||||
}
|
||||
}
|
||||
|
||||
public getRawPrivateConfig() {
|
||||
|
||||
@@ -13,18 +13,20 @@
|
||||
|
||||
import { build } from "@server/build";
|
||||
import license from "#private/license/license";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { isSubscribed } from "#private/lib/isSubscribed";
|
||||
import { Tier } from "@server/types/Tiers";
|
||||
|
||||
export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
|
||||
export async function isLicensedOrSubscribed(
|
||||
orgId: string,
|
||||
tiers: Tier[]
|
||||
): Promise<boolean> {
|
||||
if (build === "enterprise") {
|
||||
return await license.isUnlocked();
|
||||
}
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
return tier === TierId.STANDARD;
|
||||
return isSubscribed(orgId, tiers);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
29
server/private/lib/isSubscribed.ts
Normal file
29
server/private/lib/isSubscribed.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { Tier } from "@server/types/Tiers";
|
||||
|
||||
export async function isSubscribed(
|
||||
orgId: string,
|
||||
tiers: Tier[]
|
||||
): Promise<boolean> {
|
||||
if (build === "saas") {
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
const isTier = (tier && tiers.includes(tier)) || false;
|
||||
return active && isTier;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
@@ -25,7 +25,8 @@ export const privateConfigSchema = z.object({
|
||||
app: z
|
||||
.object({
|
||||
region: z.string().optional().default("default"),
|
||||
base_domain: z.string().optional()
|
||||
base_domain: z.string().optional(),
|
||||
identity_provider_mode: z.enum(["global", "org"]).optional()
|
||||
})
|
||||
.optional()
|
||||
.default({
|
||||
@@ -95,7 +96,7 @@ export const privateConfigSchema = z.object({
|
||||
.object({
|
||||
enable_redis: z.boolean().optional().default(false),
|
||||
use_pangolin_dns: z.boolean().optional().default(false),
|
||||
use_org_only_idp: z.boolean().optional().default(false)
|
||||
use_org_only_idp: z.boolean().optional()
|
||||
})
|
||||
.optional()
|
||||
.prefault({}),
|
||||
@@ -176,12 +177,34 @@ export const privateConfigSchema = z.object({
|
||||
.string()
|
||||
.optional()
|
||||
.transform(getEnvOrYaml("STRIPE_WEBHOOK_SECRET")),
|
||||
s3Bucket: z.string(),
|
||||
s3Region: z.string().default("us-east-1"),
|
||||
localFilePath: z.string()
|
||||
// s3Bucket: z.string(),
|
||||
// s3Region: z.string().default("us-east-1"),
|
||||
// localFilePath: z.string().optional()
|
||||
})
|
||||
.optional()
|
||||
});
|
||||
})
|
||||
.transform((data) => {
|
||||
// this to maintain backwards compatibility with the old config file
|
||||
const identityProviderMode = data.app?.identity_provider_mode;
|
||||
const useOrgOnlyIdp = data.flags?.use_org_only_idp;
|
||||
|
||||
if (identityProviderMode !== undefined) {
|
||||
return data;
|
||||
}
|
||||
if (useOrgOnlyIdp === true) {
|
||||
return {
|
||||
...data,
|
||||
app: { ...data.app, identity_provider_mode: "org" as const }
|
||||
};
|
||||
}
|
||||
if (useOrgOnlyIdp === false) {
|
||||
return {
|
||||
...data,
|
||||
app: { ...data.app, identity_provider_mode: "global" as const }
|
||||
};
|
||||
}
|
||||
return data;
|
||||
});
|
||||
|
||||
export function readPrivateConfigFile() {
|
||||
if (build == "oss") {
|
||||
|
||||
@@ -16,46 +16,61 @@ import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { Tier } from "@server/types/Tiers";
|
||||
|
||||
export function verifyValidSubscription(tiers: Tier[]) {
|
||||
return async function (
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
if (build != "saas") {
|
||||
return next();
|
||||
}
|
||||
|
||||
const orgId =
|
||||
req.params.orgId ||
|
||||
req.body.orgId ||
|
||||
req.query.orgId ||
|
||||
req.userOrgId;
|
||||
|
||||
if (!orgId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Organization ID is required to verify subscription"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
const isTier = tiers.includes(tier as Tier);
|
||||
if (!active) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Organization does not have an active subscription"
|
||||
)
|
||||
);
|
||||
}
|
||||
if (!isTier) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Organization subscription tier does not have access to this feature"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
export async function verifyValidSubscription(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
) {
|
||||
try {
|
||||
if (build != "saas") {
|
||||
return next();
|
||||
}
|
||||
|
||||
const orgId = req.params.orgId || req.body.orgId || req.query.orgId || req.userOrgId;
|
||||
|
||||
if (!orgId) {
|
||||
} catch (e) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Organization ID is required to verify subscription"
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Error verifying subscription"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const tier = await getOrgTierData(orgId);
|
||||
|
||||
if (!tier.active) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Organization does not have an active subscription"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
return next();
|
||||
} catch (e) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Error verifying subscription"
|
||||
)
|
||||
);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -19,8 +19,6 @@ import { fromError } from "zod-validation-error";
|
||||
|
||||
import type { Request, Response, NextFunction } from "express";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import {
|
||||
approvals,
|
||||
clients,
|
||||
@@ -273,19 +271,6 @@ export async function listApprovals(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const { approvalsList, nextCursorPending, nextCursorTimestamp } =
|
||||
await queryApprovals({
|
||||
orgId: orgId.toString(),
|
||||
|
||||
@@ -17,10 +17,7 @@ import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
import { build } from "@server/build";
|
||||
import { approvals, clients, db, orgs, type Approval } from "@server/db";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import response from "@server/lib/response";
|
||||
import { and, eq, type InferInsertModel } from "drizzle-orm";
|
||||
import type { NextFunction, Request, Response } from "express";
|
||||
@@ -64,20 +61,6 @@ export async function processPendingApproval(
|
||||
}
|
||||
|
||||
const { orgId, approvalId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const updateData = parsedBody.data;
|
||||
|
||||
const approval = await db
|
||||
|
||||
@@ -13,4 +13,3 @@
|
||||
|
||||
export * from "./transferSession";
|
||||
export * from "./getSessionTransferToken";
|
||||
export * from "./quickStart";
|
||||
|
||||
@@ -1,585 +0,0 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import {
|
||||
account,
|
||||
db,
|
||||
domainNamespaces,
|
||||
domains,
|
||||
exitNodes,
|
||||
newts,
|
||||
newtSessions,
|
||||
orgs,
|
||||
passwordResetTokens,
|
||||
Resource,
|
||||
resourcePassword,
|
||||
resourcePincode,
|
||||
resources,
|
||||
resourceWhitelist,
|
||||
roleResources,
|
||||
roles,
|
||||
roleSites,
|
||||
sites,
|
||||
targetHealthCheck,
|
||||
targets,
|
||||
userResources,
|
||||
userSites
|
||||
} from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { users } from "@server/db";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { SqliteError } from "better-sqlite3";
|
||||
import { eq, and, sql } from "drizzle-orm";
|
||||
import moment from "moment";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import config from "@server/lib/config";
|
||||
import logger from "@server/logger";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
|
||||
import { sendEmail } from "@server/emails";
|
||||
import WelcomeQuickStart from "@server/emails/templates/WelcomeQuickStart";
|
||||
import { alphabet, generateRandomString } from "oslo/crypto";
|
||||
import { createDate, TimeSpan } from "oslo";
|
||||
import { getUniqueResourceName, getUniqueSiteName } from "@server/db/names";
|
||||
import { pickPort } from "@server/routers/target/helpers";
|
||||
import { addTargets } from "@server/routers/newt/targets";
|
||||
import { isTargetValid } from "@server/lib/validators";
|
||||
import { listExitNodes } from "#private/lib/exitNodes";
|
||||
|
||||
const bodySchema = z.object({
|
||||
email: z.email().toLowerCase(),
|
||||
ip: z.string().refine(isTargetValid),
|
||||
method: z.enum(["http", "https"]),
|
||||
port: z.int().min(1).max(65535),
|
||||
pincode: z
|
||||
.string()
|
||||
.regex(/^\d{6}$/)
|
||||
.optional(),
|
||||
password: z.string().min(4).max(100).optional(),
|
||||
enableWhitelist: z.boolean().optional().default(true),
|
||||
animalId: z.string() // This is actually the secret key for the backend
|
||||
});
|
||||
|
||||
export type QuickStartBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type QuickStartResponse = {
|
||||
newtId: string;
|
||||
newtSecret: string;
|
||||
resourceUrl: string;
|
||||
completeSignUpLink: string;
|
||||
};
|
||||
|
||||
const DEMO_UBO_KEY = "b460293f-347c-4b30-837d-4e06a04d5a22";
|
||||
|
||||
export async function quickStart(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const {
|
||||
email,
|
||||
ip,
|
||||
method,
|
||||
port,
|
||||
pincode,
|
||||
password,
|
||||
enableWhitelist,
|
||||
animalId
|
||||
} = parsedBody.data;
|
||||
|
||||
try {
|
||||
const tokenValidation = validateTokenOnApi(animalId);
|
||||
|
||||
if (!tokenValidation.isValid) {
|
||||
logger.warn(
|
||||
`Quick start failed for ${email} token ${animalId}: ${tokenValidation.message}`
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid or expired token"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (animalId === DEMO_UBO_KEY) {
|
||||
if (email !== "mehrdad@getubo.com") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid email for demo Ubo key"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(
|
||||
and(
|
||||
eq(users.email, email),
|
||||
eq(users.type, UserType.Internal)
|
||||
)
|
||||
);
|
||||
|
||||
if (existing) {
|
||||
// delete the user if it already exists
|
||||
await db.delete(users).where(eq(users.userId, existing.userId));
|
||||
const orgId = `org_${existing.userId}`;
|
||||
await db.delete(orgs).where(eq(orgs.orgId, orgId));
|
||||
}
|
||||
}
|
||||
|
||||
const tempPassword = generateId(15);
|
||||
const passwordHash = await hashPassword(tempPassword);
|
||||
const userId = generateId(15);
|
||||
|
||||
// TODO: see if that user already exists?
|
||||
|
||||
// Create the sandbox user
|
||||
const existing = await db
|
||||
.select()
|
||||
.from(users)
|
||||
.where(
|
||||
and(eq(users.email, email), eq(users.type, UserType.Internal))
|
||||
);
|
||||
|
||||
if (existing && existing.length > 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A user with that email address already exists"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let newtId: string;
|
||||
let secret: string;
|
||||
let fullDomain: string;
|
||||
let resource: Resource;
|
||||
let completeSignUpLink: string;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx.insert(users).values({
|
||||
userId: userId,
|
||||
type: UserType.Internal,
|
||||
username: email,
|
||||
email: email,
|
||||
passwordHash,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
// create user"s account
|
||||
await trx.insert(account).values({
|
||||
userId
|
||||
});
|
||||
});
|
||||
|
||||
const { success, error, org } = await createUserAccountOrg(
|
||||
userId,
|
||||
email
|
||||
);
|
||||
if (!success) {
|
||||
if (error) {
|
||||
throw new Error(error);
|
||||
}
|
||||
throw new Error("Failed to create user account and organization");
|
||||
}
|
||||
if (!org) {
|
||||
throw new Error("Failed to create user account and organization");
|
||||
}
|
||||
|
||||
const orgId = org.orgId;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const token = generateRandomString(
|
||||
8,
|
||||
alphabet("0-9", "A-Z", "a-z")
|
||||
);
|
||||
|
||||
await trx
|
||||
.delete(passwordResetTokens)
|
||||
.where(eq(passwordResetTokens.userId, userId));
|
||||
|
||||
const tokenHash = await hashPassword(token);
|
||||
|
||||
await trx.insert(passwordResetTokens).values({
|
||||
userId: userId,
|
||||
email: email,
|
||||
tokenHash,
|
||||
expiresAt: createDate(new TimeSpan(7, "d")).getTime()
|
||||
});
|
||||
|
||||
// // Create the sandbox newt
|
||||
// const newClientAddress = await getNextAvailableClientSubnet(orgId);
|
||||
// if (!newClientAddress) {
|
||||
// throw new Error("No available subnet found");
|
||||
// }
|
||||
|
||||
// const clientAddress = newClientAddress.split("/")[0];
|
||||
|
||||
newtId = generateId(15);
|
||||
secret = generateId(48);
|
||||
|
||||
// Create the sandbox site
|
||||
const siteNiceId = await getUniqueSiteName(orgId);
|
||||
const siteName = `First Site`;
|
||||
|
||||
// pick a random exit node
|
||||
const exitNodesList = await listExitNodes(orgId);
|
||||
|
||||
// select a random exit node
|
||||
const randomExitNode =
|
||||
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
|
||||
|
||||
if (!randomExitNode) {
|
||||
throw new Error("No exit nodes available");
|
||||
}
|
||||
|
||||
const [newSite] = await trx
|
||||
.insert(sites)
|
||||
.values({
|
||||
orgId,
|
||||
exitNodeId: randomExitNode.exitNodeId,
|
||||
name: siteName,
|
||||
niceId: siteNiceId,
|
||||
// address: clientAddress,
|
||||
type: "newt",
|
||||
dockerSocketEnabled: true
|
||||
})
|
||||
.returning();
|
||||
|
||||
const siteId = newSite.siteId;
|
||||
|
||||
const adminRole = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (adminRole.length === 0) {
|
||||
throw new Error("Admin role not found");
|
||||
}
|
||||
|
||||
await trx.insert(roleSites).values({
|
||||
roleId: adminRole[0].roleId,
|
||||
siteId: newSite.siteId
|
||||
});
|
||||
|
||||
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
|
||||
// make sure the user can access the site
|
||||
await trx.insert(userSites).values({
|
||||
userId: req.user?.userId!,
|
||||
siteId: newSite.siteId
|
||||
});
|
||||
}
|
||||
|
||||
// add the peer to the exit node
|
||||
const secretHash = await hashPassword(secret!);
|
||||
|
||||
await trx.insert(newts).values({
|
||||
newtId: newtId!,
|
||||
secretHash,
|
||||
siteId: newSite.siteId,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
const [randomNamespace] = await trx
|
||||
.select()
|
||||
.from(domainNamespaces)
|
||||
.orderBy(sql`RANDOM()`)
|
||||
.limit(1);
|
||||
|
||||
if (!randomNamespace) {
|
||||
throw new Error("No domain namespace available");
|
||||
}
|
||||
|
||||
const [randomNamespaceDomain] = await trx
|
||||
.select()
|
||||
.from(domains)
|
||||
.where(eq(domains.domainId, randomNamespace.domainId))
|
||||
.limit(1);
|
||||
|
||||
if (!randomNamespaceDomain) {
|
||||
throw new Error("No domain found for the namespace");
|
||||
}
|
||||
|
||||
const resourceNiceId = await getUniqueResourceName(orgId);
|
||||
|
||||
// Create sandbox resource
|
||||
const subdomain = `${resourceNiceId}-${generateId(5)}`;
|
||||
fullDomain = `${subdomain}.${randomNamespaceDomain.baseDomain}`;
|
||||
|
||||
const resourceName = `First Resource`;
|
||||
|
||||
const newResource = await trx
|
||||
.insert(resources)
|
||||
.values({
|
||||
niceId: resourceNiceId,
|
||||
fullDomain,
|
||||
domainId: randomNamespaceDomain.domainId,
|
||||
orgId,
|
||||
name: resourceName,
|
||||
subdomain,
|
||||
http: true,
|
||||
protocol: "tcp",
|
||||
ssl: true,
|
||||
sso: false,
|
||||
emailWhitelistEnabled: enableWhitelist
|
||||
})
|
||||
.returning();
|
||||
|
||||
await trx.insert(roleResources).values({
|
||||
roleId: adminRole[0].roleId,
|
||||
resourceId: newResource[0].resourceId
|
||||
});
|
||||
|
||||
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
|
||||
// make sure the user can access the resource
|
||||
await trx.insert(userResources).values({
|
||||
userId: req.user?.userId!,
|
||||
resourceId: newResource[0].resourceId
|
||||
});
|
||||
}
|
||||
|
||||
resource = newResource[0];
|
||||
|
||||
// Create the sandbox target
|
||||
const { internalPort, targetIps } = await pickPort(siteId!, trx);
|
||||
|
||||
if (!internalPort) {
|
||||
throw new Error("No available internal port");
|
||||
}
|
||||
|
||||
const newTarget = await trx
|
||||
.insert(targets)
|
||||
.values({
|
||||
resourceId: resource.resourceId,
|
||||
siteId: siteId!,
|
||||
internalPort,
|
||||
ip,
|
||||
method,
|
||||
port,
|
||||
enabled: true
|
||||
})
|
||||
.returning();
|
||||
|
||||
const newHealthcheck = await trx
|
||||
.insert(targetHealthCheck)
|
||||
.values({
|
||||
targetId: newTarget[0].targetId,
|
||||
hcEnabled: false
|
||||
})
|
||||
.returning();
|
||||
|
||||
// add the new target to the targetIps array
|
||||
targetIps.push(`${ip}/32`);
|
||||
|
||||
const [newt] = await trx
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId!))
|
||||
.limit(1);
|
||||
|
||||
await addTargets(
|
||||
newt.newtId,
|
||||
newTarget,
|
||||
newHealthcheck,
|
||||
resource.protocol
|
||||
);
|
||||
|
||||
// Set resource pincode if provided
|
||||
if (pincode) {
|
||||
await trx
|
||||
.delete(resourcePincode)
|
||||
.where(
|
||||
eq(resourcePincode.resourceId, resource!.resourceId)
|
||||
);
|
||||
|
||||
const pincodeHash = await hashPassword(pincode);
|
||||
|
||||
await trx.insert(resourcePincode).values({
|
||||
resourceId: resource!.resourceId,
|
||||
pincodeHash,
|
||||
digitLength: 6
|
||||
});
|
||||
}
|
||||
|
||||
// Set resource password if provided
|
||||
if (password) {
|
||||
await trx
|
||||
.delete(resourcePassword)
|
||||
.where(
|
||||
eq(resourcePassword.resourceId, resource!.resourceId)
|
||||
);
|
||||
|
||||
const passwordHash = await hashPassword(password);
|
||||
|
||||
await trx.insert(resourcePassword).values({
|
||||
resourceId: resource!.resourceId,
|
||||
passwordHash
|
||||
});
|
||||
}
|
||||
|
||||
// Set resource OTP if whitelist is enabled
|
||||
if (enableWhitelist) {
|
||||
await trx.insert(resourceWhitelist).values({
|
||||
email,
|
||||
resourceId: resource!.resourceId
|
||||
});
|
||||
}
|
||||
|
||||
completeSignUpLink = `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}&token=${token}`;
|
||||
|
||||
// Store token for email outside transaction
|
||||
await sendEmail(
|
||||
WelcomeQuickStart({
|
||||
username: email,
|
||||
link: completeSignUpLink,
|
||||
fallbackLink: `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}`,
|
||||
resourceMethod: method,
|
||||
resourceHostname: ip,
|
||||
resourcePort: port,
|
||||
resourceUrl: `https://${fullDomain}`,
|
||||
cliCommand: `newt --id ${newtId} --secret ${secret}`
|
||||
}),
|
||||
{
|
||||
to: email,
|
||||
from: config.getNoReplyEmail(),
|
||||
subject: `Access your Pangolin dashboard and resources`
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
return response<QuickStartResponse>(res, {
|
||||
data: {
|
||||
newtId: newtId!,
|
||||
newtSecret: secret!,
|
||||
resourceUrl: `https://${fullDomain!}`,
|
||||
completeSignUpLink: completeSignUpLink!
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Quick start completed successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Account already exists with that email. Email: ${email}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"A user with that email address already exists"
|
||||
)
|
||||
);
|
||||
} else {
|
||||
logger.error(e);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to do quick start"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const BACKEND_SECRET_KEY = "4f9b6000-5d1a-11f0-9de7-ff2cc032f501";
|
||||
|
||||
/**
|
||||
* Validates a token received from the frontend.
|
||||
* @param {string} token The validation token from the request.
|
||||
* @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid.
|
||||
*/
|
||||
const validateTokenOnApi = (
|
||||
token: string
|
||||
): { isValid: boolean; message: string } => {
|
||||
if (token === DEMO_UBO_KEY) {
|
||||
// Special case for demo UBO key
|
||||
return { isValid: true, message: "Demo UBO key is valid." };
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
return { isValid: false, message: "Error: No token provided." };
|
||||
}
|
||||
|
||||
try {
|
||||
// 1. Decode the base64 string
|
||||
const decodedB64 = atob(token);
|
||||
|
||||
// 2. Reverse the character code manipulation
|
||||
const deobfuscated = decodedB64
|
||||
.split("")
|
||||
.map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift
|
||||
.join("");
|
||||
|
||||
// 3. Split the data to get the original secret and timestamp
|
||||
const parts = deobfuscated.split("|");
|
||||
if (parts.length !== 2) {
|
||||
throw new Error("Invalid token format.");
|
||||
}
|
||||
const receivedKey = parts[0];
|
||||
const tokenTimestamp = parseInt(parts[1], 10);
|
||||
|
||||
// 4. Check if the secret key matches
|
||||
if (receivedKey !== BACKEND_SECRET_KEY) {
|
||||
return { isValid: false, message: "Invalid token: Key mismatch." };
|
||||
}
|
||||
|
||||
// 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks
|
||||
const now = Date.now();
|
||||
const timeDifference = now - tokenTimestamp;
|
||||
|
||||
if (timeDifference > 30000) {
|
||||
// 30 seconds
|
||||
return { isValid: false, message: "Invalid token: Expired." };
|
||||
}
|
||||
|
||||
if (timeDifference < 0) {
|
||||
// Timestamp is in the future
|
||||
return {
|
||||
isValid: false,
|
||||
message: "Invalid token: Timestamp is in the future."
|
||||
};
|
||||
}
|
||||
|
||||
// If all checks pass, the token is valid
|
||||
return { isValid: true, message: "Token is valid!" };
|
||||
} catch (error) {
|
||||
// This will catch errors from atob (if not valid base64) or other issues.
|
||||
return {
|
||||
isValid: false,
|
||||
message: `Error: ${(error as Error).message}`
|
||||
};
|
||||
}
|
||||
};
|
||||
268
server/private/routers/billing/changeTier.ts
Normal file
268
server/private/routers/billing/changeTier.ts
Normal file
@@ -0,0 +1,268 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { customers, db, subscriptions, subscriptionItems } from "@server/db";
|
||||
import { eq, and, or } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import {
|
||||
getTier1FeaturePriceSet,
|
||||
getTier3FeaturePriceSet,
|
||||
getTier2FeaturePriceSet,
|
||||
FeatureId,
|
||||
type FeaturePriceSet
|
||||
} from "@server/lib/billing";
|
||||
import { getLineItems } from "@server/lib/billing/getLineItems";
|
||||
|
||||
const changeTierSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const changeTierBodySchema = z.strictObject({
|
||||
tier: z.enum(["tier1", "tier2", "tier3"])
|
||||
});
|
||||
|
||||
export async function changeTier(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = changeTierSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const parsedBody = changeTierBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { tier } = parsedBody.data;
|
||||
|
||||
// Get the customer for this org
|
||||
const [customer] = await db
|
||||
.select()
|
||||
.from(customers)
|
||||
.where(eq(customers.orgId, orgId))
|
||||
.limit(1);
|
||||
|
||||
if (!customer) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No customer found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Get the active subscription for this customer
|
||||
const [subscription] = await db
|
||||
.select()
|
||||
.from(subscriptions)
|
||||
.where(
|
||||
and(
|
||||
eq(subscriptions.customerId, customer.customerId),
|
||||
eq(subscriptions.status, "active"),
|
||||
or(
|
||||
eq(subscriptions.type, "tier1"),
|
||||
eq(subscriptions.type, "tier2"),
|
||||
eq(subscriptions.type, "tier3")
|
||||
)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!subscription) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No active subscription found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Get the target tier's price set
|
||||
let targetPriceSet: FeaturePriceSet;
|
||||
if (tier === "tier1") {
|
||||
targetPriceSet = getTier1FeaturePriceSet();
|
||||
} else if (tier === "tier2") {
|
||||
targetPriceSet = getTier2FeaturePriceSet();
|
||||
} else if (tier === "tier3") {
|
||||
targetPriceSet = getTier3FeaturePriceSet();
|
||||
} else {
|
||||
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid tier"));
|
||||
}
|
||||
|
||||
// Get current subscription items from our database
|
||||
const currentItems = await db
|
||||
.select()
|
||||
.from(subscriptionItems)
|
||||
.where(
|
||||
eq(
|
||||
subscriptionItems.subscriptionId,
|
||||
subscription.subscriptionId
|
||||
)
|
||||
);
|
||||
|
||||
if (currentItems.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No subscription items found"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Retrieve the full subscription from Stripe to get item IDs
|
||||
const stripeSubscription = await stripe!.subscriptions.retrieve(
|
||||
subscription.subscriptionId
|
||||
);
|
||||
|
||||
// Determine if we're switching between different products
|
||||
// tier1 uses TIER1 product, tier2/tier3 use USERS product
|
||||
const currentTier = subscription.type;
|
||||
const switchingProducts =
|
||||
(currentTier === "tier1" &&
|
||||
(tier === "tier2" || tier === "tier3")) ||
|
||||
((currentTier === "tier2" || currentTier === "tier3") &&
|
||||
tier === "tier1");
|
||||
|
||||
let updatedSubscription;
|
||||
|
||||
if (switchingProducts) {
|
||||
// When switching between different products, we need to:
|
||||
// 1. Delete old subscription items
|
||||
// 2. Add new subscription items
|
||||
logger.info(
|
||||
`Switching products from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}`
|
||||
);
|
||||
|
||||
// Build array to delete all existing items and add new ones
|
||||
const itemsToUpdate: any[] = [];
|
||||
|
||||
// Mark all existing items for deletion
|
||||
for (const stripeItem of stripeSubscription.items.data) {
|
||||
itemsToUpdate.push({
|
||||
id: stripeItem.id,
|
||||
deleted: true
|
||||
});
|
||||
}
|
||||
|
||||
// Add new items for the target tier
|
||||
const newLineItems = await getLineItems(targetPriceSet, orgId);
|
||||
for (const lineItem of newLineItems) {
|
||||
itemsToUpdate.push(lineItem);
|
||||
}
|
||||
|
||||
updatedSubscription = await stripe!.subscriptions.update(
|
||||
subscription.subscriptionId,
|
||||
{
|
||||
items: itemsToUpdate,
|
||||
proration_behavior: "create_prorations"
|
||||
}
|
||||
);
|
||||
} else {
|
||||
// Same product, different price tier (tier2 <-> tier3)
|
||||
// We can simply update the price
|
||||
logger.info(
|
||||
`Updating price from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}`
|
||||
);
|
||||
|
||||
const itemsToUpdate = stripeSubscription.items.data.map(
|
||||
(stripeItem) => {
|
||||
// Find the corresponding item in our database
|
||||
const dbItem = currentItems.find(
|
||||
(item) => item.priceId === stripeItem.price.id
|
||||
);
|
||||
|
||||
if (!dbItem) {
|
||||
// Keep the existing item unchanged if we can't find it
|
||||
return {
|
||||
id: stripeItem.id,
|
||||
price: stripeItem.price.id,
|
||||
quantity: stripeItem.quantity
|
||||
};
|
||||
}
|
||||
|
||||
// Map to the corresponding feature in the new tier
|
||||
const newPriceId = targetPriceSet[FeatureId.USERS];
|
||||
|
||||
if (newPriceId) {
|
||||
return {
|
||||
id: stripeItem.id,
|
||||
price: newPriceId,
|
||||
quantity: stripeItem.quantity
|
||||
};
|
||||
}
|
||||
|
||||
// If no mapping found, keep existing
|
||||
return {
|
||||
id: stripeItem.id,
|
||||
price: stripeItem.price.id,
|
||||
quantity: stripeItem.quantity
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
updatedSubscription = await stripe!.subscriptions.update(
|
||||
subscription.subscriptionId,
|
||||
{
|
||||
items: itemsToUpdate,
|
||||
proration_behavior: "create_prorations"
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Successfully changed tier to ${tier} for org ${orgId}, subscription ${subscription.subscriptionId}`
|
||||
);
|
||||
|
||||
return response<{ subscriptionId: string; newTier: string }>(res, {
|
||||
data: {
|
||||
subscriptionId: updatedSubscription.id,
|
||||
newTier: tier
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Tier change successful",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error("Error changing tier:", error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"An error occurred while changing tier"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -22,14 +22,23 @@ import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
|
||||
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
|
||||
import {
|
||||
getTier1FeaturePriceSet,
|
||||
getTier3FeaturePriceSet,
|
||||
getTier2FeaturePriceSet
|
||||
} from "@server/lib/billing";
|
||||
import { getLineItems } from "@server/lib/billing/getLineItems";
|
||||
import Stripe from "stripe";
|
||||
|
||||
const createCheckoutSessionSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function createCheckoutSessionSAAS(
|
||||
const createCheckoutSessionBodySchema = z.strictObject({
|
||||
tier: z.enum(["tier1", "tier2", "tier3"])
|
||||
});
|
||||
|
||||
export async function createCheckoutSession(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
@@ -47,6 +56,18 @@ export async function createCheckoutSessionSAAS(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
const parsedBody = createCheckoutSessionBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { tier } = parsedBody.data;
|
||||
|
||||
// check if we already have a customer for this org
|
||||
const [customer] = await db
|
||||
.select()
|
||||
@@ -65,20 +86,26 @@ export async function createCheckoutSessionSAAS(
|
||||
);
|
||||
}
|
||||
|
||||
const standardTierPrice = getTierPriceSet()[TierId.STANDARD];
|
||||
let lineItems: Stripe.Checkout.SessionCreateParams.LineItem[];
|
||||
if (tier === "tier1") {
|
||||
lineItems = await getLineItems(getTier1FeaturePriceSet(), orgId);
|
||||
} else if (tier === "tier2") {
|
||||
lineItems = await getLineItems(getTier2FeaturePriceSet(), orgId);
|
||||
} else if (tier === "tier3") {
|
||||
lineItems = await getLineItems(getTier3FeaturePriceSet(), orgId);
|
||||
} else {
|
||||
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan"));
|
||||
}
|
||||
|
||||
logger.debug(`Line items: ${JSON.stringify(lineItems)}`);
|
||||
|
||||
const session = await stripe!.checkout.sessions.create({
|
||||
client_reference_id: orgId, // So we can look it up the org later on the webhook
|
||||
billing_address_collection: "required",
|
||||
line_items: [
|
||||
{
|
||||
price: standardTierPrice, // Use the standard tier
|
||||
quantity: 1
|
||||
},
|
||||
...getLineItems(getStandardFeaturePriceSet())
|
||||
], // Start with the standard feature set that matches the free limits
|
||||
line_items: lineItems,
|
||||
customer: customer.customerId,
|
||||
mode: "subscription",
|
||||
allow_promotion_codes: true,
|
||||
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?success=true&session_id={CHECKOUT_SESSION_ID}`,
|
||||
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?canceled=true`
|
||||
});
|
||||
410
server/private/routers/billing/featureLifecycle.ts
Normal file
410
server/private/routers/billing/featureLifecycle.ts
Normal file
@@ -0,0 +1,410 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { SubscriptionType } from "./hooks/getSubType";
|
||||
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import { Tier } from "@server/types/Tiers";
|
||||
import logger from "@server/logger";
|
||||
import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
/**
|
||||
* Get the maximum allowed retention days for a given tier
|
||||
* Returns null for enterprise tier (unlimited)
|
||||
*/
|
||||
function getMaxRetentionDaysForTier(tier: Tier | null): number | null {
|
||||
if (!tier) {
|
||||
return 3; // Free tier
|
||||
}
|
||||
|
||||
switch (tier) {
|
||||
case "tier1":
|
||||
return 7;
|
||||
case "tier2":
|
||||
return 30;
|
||||
case "tier3":
|
||||
return 90;
|
||||
case "enterprise":
|
||||
return null; // No limit
|
||||
default:
|
||||
return 3; // Default to free tier limit
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cap retention days to the maximum allowed for the given tier
|
||||
*/
|
||||
async function capRetentionDays(
|
||||
orgId: string,
|
||||
tier: Tier | null
|
||||
): Promise<void> {
|
||||
const maxRetentionDays = getMaxRetentionDaysForTier(tier);
|
||||
|
||||
// If there's no limit (enterprise tier), no capping needed
|
||||
if (maxRetentionDays === null) {
|
||||
logger.debug(
|
||||
`No retention day limit for org ${orgId} on tier ${tier || "free"}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get current org settings
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
if (!org) {
|
||||
logger.warn(`Org ${orgId} not found when capping retention days`);
|
||||
return;
|
||||
}
|
||||
|
||||
const updates: Partial<typeof orgs.$inferInsert> = {};
|
||||
let needsUpdate = false;
|
||||
|
||||
// Cap request log retention if it exceeds the limit
|
||||
if (
|
||||
org.settingsLogRetentionDaysRequest !== null &&
|
||||
org.settingsLogRetentionDaysRequest > maxRetentionDays
|
||||
) {
|
||||
updates.settingsLogRetentionDaysRequest = maxRetentionDays;
|
||||
needsUpdate = true;
|
||||
logger.info(
|
||||
`Capping request log retention from ${org.settingsLogRetentionDaysRequest} to ${maxRetentionDays} days for org ${orgId}`
|
||||
);
|
||||
}
|
||||
|
||||
// Cap access log retention if it exceeds the limit
|
||||
if (
|
||||
org.settingsLogRetentionDaysAccess !== null &&
|
||||
org.settingsLogRetentionDaysAccess > maxRetentionDays
|
||||
) {
|
||||
updates.settingsLogRetentionDaysAccess = maxRetentionDays;
|
||||
needsUpdate = true;
|
||||
logger.info(
|
||||
`Capping access log retention from ${org.settingsLogRetentionDaysAccess} to ${maxRetentionDays} days for org ${orgId}`
|
||||
);
|
||||
}
|
||||
|
||||
// Cap action log retention if it exceeds the limit
|
||||
if (
|
||||
org.settingsLogRetentionDaysAction !== null &&
|
||||
org.settingsLogRetentionDaysAction > maxRetentionDays
|
||||
) {
|
||||
updates.settingsLogRetentionDaysAction = maxRetentionDays;
|
||||
needsUpdate = true;
|
||||
logger.info(
|
||||
`Capping action log retention from ${org.settingsLogRetentionDaysAction} to ${maxRetentionDays} days for org ${orgId}`
|
||||
);
|
||||
}
|
||||
|
||||
// Apply updates if needed
|
||||
if (needsUpdate) {
|
||||
await db
|
||||
.update(orgs)
|
||||
.set(updates)
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
logger.info(
|
||||
`Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days`
|
||||
);
|
||||
} else {
|
||||
logger.debug(
|
||||
`No retention day capping needed for org ${orgId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export async function handleTierChange(
|
||||
orgId: string,
|
||||
newTier: SubscriptionType | null,
|
||||
previousTier?: SubscriptionType | null
|
||||
): Promise<void> {
|
||||
logger.info(
|
||||
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
|
||||
);
|
||||
|
||||
// License subscriptions are handled separately and don't use the tier matrix
|
||||
if (newTier === "license") {
|
||||
logger.debug(
|
||||
`New tier is license for org ${orgId}, no feature lifecycle handling needed`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// If newTier is null, treat as free tier - disable all features
|
||||
if (newTier === null) {
|
||||
logger.info(
|
||||
`Org ${orgId} is reverting to free tier, disabling all paid features`
|
||||
);
|
||||
// Cap retention days to free tier limits
|
||||
await capRetentionDays(orgId, null);
|
||||
|
||||
// Disable all features in the tier matrix
|
||||
for (const [featureKey] of Object.entries(tierMatrix)) {
|
||||
const feature = featureKey as TierFeature;
|
||||
logger.info(
|
||||
`Feature ${feature} is not available in free tier for org ${orgId}. Disabling...`
|
||||
);
|
||||
await disableFeature(orgId, feature);
|
||||
}
|
||||
logger.info(
|
||||
`Completed free tier feature lifecycle handling for org ${orgId}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Get the tier (cast as Tier since we've ruled out "license" and null)
|
||||
const tier = newTier as Tier;
|
||||
|
||||
// Cap retention days to the new tier's limits
|
||||
await capRetentionDays(orgId, tier);
|
||||
|
||||
// Check each feature in the tier matrix
|
||||
for (const [featureKey, allowedTiers] of Object.entries(tierMatrix)) {
|
||||
const feature = featureKey as TierFeature;
|
||||
const isFeatureAvailable = allowedTiers.includes(tier);
|
||||
|
||||
if (!isFeatureAvailable) {
|
||||
logger.info(
|
||||
`Feature ${feature} is not available in tier ${tier} for org ${orgId}. Disabling...`
|
||||
);
|
||||
await disableFeature(orgId, feature);
|
||||
} else {
|
||||
logger.debug(
|
||||
`Feature ${feature} is available in tier ${tier} for org ${orgId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Completed tier change feature lifecycle handling for org ${orgId}`
|
||||
);
|
||||
}
|
||||
|
||||
async function disableFeature(
|
||||
orgId: string,
|
||||
feature: TierFeature
|
||||
): Promise<void> {
|
||||
try {
|
||||
switch (feature) {
|
||||
case TierFeature.OrgOidc:
|
||||
await disableOrgOidc(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.LoginPageDomain:
|
||||
await disableLoginPageDomain(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.DeviceApprovals:
|
||||
await disableDeviceApprovals(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.LoginPageBranding:
|
||||
await disableLoginPageBranding(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.LogExport:
|
||||
await disableLogExport(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.AccessLogs:
|
||||
await disableAccessLogs(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.ActionLogs:
|
||||
await disableActionLogs(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.RotateCredentials:
|
||||
await disableRotateCredentials(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.MaintencePage:
|
||||
await disableMaintencePage(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.DevicePosture:
|
||||
await disableDevicePosture(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.TwoFactorEnforcement:
|
||||
await disableTwoFactorEnforcement(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.SessionDurationPolicies:
|
||||
await disableSessionDurationPolicies(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.PasswordExpirationPolicies:
|
||||
await disablePasswordExpirationPolicies(orgId);
|
||||
break;
|
||||
|
||||
case TierFeature.AutoProvisioning:
|
||||
await disableAutoProvisioning(orgId);
|
||||
break;
|
||||
|
||||
default:
|
||||
logger.warn(
|
||||
`Unknown feature ${feature} for org ${orgId}, skipping`
|
||||
);
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Successfully disabled feature ${feature} for org ${orgId}`
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error disabling feature ${feature} for org ${orgId}:`,
|
||||
error
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async function disableOrgOidc(orgId: string): Promise<void> {}
|
||||
|
||||
async function disableDeviceApprovals(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(roles)
|
||||
.set({ requireDeviceApproval: false })
|
||||
.where(eq(roles.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled device approvals on all roles for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disableLoginPageBranding(orgId: string): Promise<void> {
|
||||
const [existingBranding] = await db
|
||||
.select()
|
||||
.from(loginPageBrandingOrg)
|
||||
.where(eq(loginPageBrandingOrg.orgId, orgId));
|
||||
|
||||
if (existingBranding) {
|
||||
await db
|
||||
.delete(loginPageBranding)
|
||||
.where(
|
||||
eq(
|
||||
loginPageBranding.loginPageBrandingId,
|
||||
existingBranding.loginPageBrandingId
|
||||
)
|
||||
);
|
||||
|
||||
logger.info(`Disabled login page branding for org ${orgId}`);
|
||||
}
|
||||
}
|
||||
|
||||
async function disableLoginPageDomain(orgId: string): Promise<void> {
|
||||
const [existingLoginPage] = await db
|
||||
.select()
|
||||
.from(loginPageOrg)
|
||||
.where(eq(loginPageOrg.orgId, orgId))
|
||||
.innerJoin(
|
||||
loginPage,
|
||||
eq(loginPage.loginPageId, loginPageOrg.loginPageId)
|
||||
);
|
||||
|
||||
if (existingLoginPage) {
|
||||
await db
|
||||
.delete(loginPageOrg)
|
||||
.where(eq(loginPageOrg.orgId, orgId));
|
||||
|
||||
await db
|
||||
.delete(loginPage)
|
||||
.where(
|
||||
eq(
|
||||
loginPage.loginPageId,
|
||||
existingLoginPage.loginPageOrg.loginPageId
|
||||
)
|
||||
);
|
||||
|
||||
logger.info(`Disabled login page domain for org ${orgId}`);
|
||||
}
|
||||
}
|
||||
|
||||
async function disableLogExport(orgId: string): Promise<void> {}
|
||||
|
||||
async function disableAccessLogs(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(orgs)
|
||||
.set({ settingsLogRetentionDaysAccess: 0 })
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled access logs for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disableActionLogs(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(orgs)
|
||||
.set({ settingsLogRetentionDaysAction: 0 })
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled action logs for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disableRotateCredentials(orgId: string): Promise<void> {}
|
||||
|
||||
async function disableMaintencePage(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(resources)
|
||||
.set({
|
||||
maintenanceModeEnabled: false
|
||||
})
|
||||
.where(eq(resources.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled maintenance page on all resources for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disableDevicePosture(orgId: string): Promise<void> {}
|
||||
|
||||
async function disableTwoFactorEnforcement(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(orgs)
|
||||
.set({ requireTwoFactor: false })
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled two-factor enforcement for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disableSessionDurationPolicies(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(orgs)
|
||||
.set({ maxSessionLengthHours: null })
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled session duration policies for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disablePasswordExpirationPolicies(orgId: string): Promise<void> {
|
||||
await db
|
||||
.update(orgs)
|
||||
.set({ passwordExpiryDays: null })
|
||||
.where(eq(orgs.orgId, orgId));
|
||||
|
||||
logger.info(`Disabled password expiration policies for org ${orgId}`);
|
||||
}
|
||||
|
||||
async function disableAutoProvisioning(orgId: string): Promise<void> {
|
||||
// Get all IDP IDs for this org through the idpOrg join table
|
||||
const orgIdps = await db
|
||||
.select({ idpId: idpOrg.idpId })
|
||||
.from(idpOrg)
|
||||
.where(eq(idpOrg.orgId, orgId));
|
||||
|
||||
// Update autoProvision to false for all IDPs in this org
|
||||
for (const { idpId } of orgIdps) {
|
||||
await db
|
||||
.update(idp)
|
||||
.set({ autoProvision: false })
|
||||
.where(eq(idp.idpId, idpId));
|
||||
}
|
||||
}
|
||||
@@ -23,6 +23,8 @@ import logger from "@server/logger";
|
||||
import { fromZodError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { GetOrgSubscriptionResponse } from "@server/routers/billing/types";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { build } from "@server/build";
|
||||
|
||||
// Import tables for billing
|
||||
import {
|
||||
@@ -70,9 +72,19 @@ export async function getOrgSubscriptions(
|
||||
throw err;
|
||||
}
|
||||
|
||||
let limitsExceeded = false;
|
||||
if (build === "saas") {
|
||||
try {
|
||||
limitsExceeded = await usageService.checkLimitSet(orgId);
|
||||
} catch (err) {
|
||||
logger.error("Error checking limits for org %s: %s", orgId, err);
|
||||
}
|
||||
}
|
||||
|
||||
return response<GetOrgSubscriptionResponse>(res, {
|
||||
data: {
|
||||
subscriptions
|
||||
subscriptions,
|
||||
...(build === "saas" ? { limitsExceeded } : {})
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
|
||||
@@ -78,16 +78,10 @@ export async function getOrgUsage(
|
||||
// Get usage for org
|
||||
const usageData = [];
|
||||
|
||||
const siteUptime = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.SITE_UPTIME
|
||||
);
|
||||
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
|
||||
const domains = await usageService.getUsageDaily(
|
||||
orgId,
|
||||
FeatureId.DOMAINS
|
||||
);
|
||||
const remoteExitNodes = await usageService.getUsageDaily(
|
||||
const sites = await usageService.getUsage(orgId, FeatureId.SITES);
|
||||
const users = await usageService.getUsage(orgId, FeatureId.USERS);
|
||||
const domains = await usageService.getUsage(orgId, FeatureId.DOMAINS);
|
||||
const remoteExitNodes = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES
|
||||
);
|
||||
@@ -96,8 +90,8 @@ export async function getOrgUsage(
|
||||
FeatureId.EGRESS_DATA_MB
|
||||
);
|
||||
|
||||
if (siteUptime) {
|
||||
usageData.push(siteUptime);
|
||||
if (sites) {
|
||||
usageData.push(sites);
|
||||
}
|
||||
if (users) {
|
||||
usageData.push(users);
|
||||
|
||||
@@ -1,35 +1,62 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import {
|
||||
getLicensePriceSet,
|
||||
} from "@server/lib/billing/licenses";
|
||||
import {
|
||||
getTierPriceSet,
|
||||
} from "@server/lib/billing/tiers";
|
||||
getTier1FeaturePriceSet,
|
||||
getTier2FeaturePriceSet,
|
||||
getTier3FeaturePriceSet,
|
||||
} from "@server/lib/billing/features";
|
||||
import Stripe from "stripe";
|
||||
import { Tier } from "@server/types/Tiers";
|
||||
|
||||
export function getSubType(fullSubscription: Stripe.Response<Stripe.Subscription>): "saas" | "license" {
|
||||
export type SubscriptionType = Tier | "license";
|
||||
|
||||
export function getSubType(fullSubscription: Stripe.Response<Stripe.Subscription>): SubscriptionType | null {
|
||||
// Determine subscription type by checking subscription items
|
||||
let type: "saas" | "license" = "saas";
|
||||
if (Array.isArray(fullSubscription.items?.data)) {
|
||||
for (const item of fullSubscription.items.data) {
|
||||
const priceId = item.price.id;
|
||||
if (!Array.isArray(fullSubscription.items?.data) || fullSubscription.items.data.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if price ID matches any license price
|
||||
const licensePrices = Object.values(getLicensePriceSet());
|
||||
for (const item of fullSubscription.items.data) {
|
||||
const priceId = item.price.id;
|
||||
|
||||
if (licensePrices.includes(priceId)) {
|
||||
type = "license";
|
||||
break;
|
||||
}
|
||||
// Check if price ID matches any license price
|
||||
const licensePrices = Object.values(getLicensePriceSet());
|
||||
if (licensePrices.includes(priceId)) {
|
||||
return "license";
|
||||
}
|
||||
|
||||
// Check if price ID matches any tier price (saas)
|
||||
const tierPrices = Object.values(getTierPriceSet());
|
||||
// Check if price ID matches home lab tier
|
||||
const homeLabPrices = Object.values(getTier1FeaturePriceSet());
|
||||
if (homeLabPrices.includes(priceId)) {
|
||||
return "tier1";
|
||||
}
|
||||
|
||||
if (tierPrices.includes(priceId)) {
|
||||
type = "saas";
|
||||
break;
|
||||
}
|
||||
// Check if price ID matches tier2 tier
|
||||
const tier2Prices = Object.values(getTier2FeaturePriceSet());
|
||||
if (tier2Prices.includes(priceId)) {
|
||||
return "tier2";
|
||||
}
|
||||
|
||||
// Check if price ID matches tier3 tier
|
||||
const tier3Prices = Object.values(getTier3FeaturePriceSet());
|
||||
if (tier3Prices.includes(priceId)) {
|
||||
return "tier3";
|
||||
}
|
||||
}
|
||||
|
||||
return type;
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@ import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
|
||||
import { sendEmail } from "@server/emails";
|
||||
import EnterpriseEditionKeyGenerated from "@server/emails/templates/EnterpriseEditionKeyGenerated";
|
||||
import config from "@server/lib/config";
|
||||
import { getFeatureIdByPriceId } from "@server/lib/billing/features";
|
||||
import { handleTierChange } from "../featureLifecycle";
|
||||
|
||||
export async function handleSubscriptionCreated(
|
||||
subscription: Stripe.Subscription
|
||||
@@ -59,6 +61,8 @@ export async function handleSubscriptionCreated(
|
||||
return;
|
||||
}
|
||||
|
||||
const type = getSubType(fullSubscription);
|
||||
|
||||
const newSubscription = {
|
||||
subscriptionId: subscription.id,
|
||||
customerId: subscription.customer as string,
|
||||
@@ -66,7 +70,9 @@ export async function handleSubscriptionCreated(
|
||||
canceledAt: subscription.canceled_at
|
||||
? subscription.canceled_at
|
||||
: null,
|
||||
createdAt: subscription.created
|
||||
createdAt: subscription.created,
|
||||
type: type,
|
||||
version: 1 // we are hardcoding the initial version when the subscription is created, and then we will increment it on every update
|
||||
};
|
||||
|
||||
await db.insert(subscriptions).values(newSubscription);
|
||||
@@ -87,10 +93,15 @@ export async function handleSubscriptionCreated(
|
||||
name = product.name || null;
|
||||
}
|
||||
|
||||
// Get the feature ID from the price ID
|
||||
const featureId = getFeatureIdByPriceId(item.price.id);
|
||||
|
||||
return {
|
||||
stripeSubscriptionItemId: item.id,
|
||||
subscriptionId: subscription.id,
|
||||
planId: item.plan.id,
|
||||
priceId: item.price.id,
|
||||
featureId: featureId || null,
|
||||
meterId: item.plan.meter,
|
||||
unitAmount: item.price.unit_amount || 0,
|
||||
currentPeriodStart: item.current_period_start,
|
||||
@@ -129,17 +140,23 @@ export async function handleSubscriptionCreated(
|
||||
return;
|
||||
}
|
||||
|
||||
const type = getSubType(fullSubscription);
|
||||
if (type === "saas") {
|
||||
if (type === "tier1" || type === "tier2" || type === "tier3") {
|
||||
logger.debug(
|
||||
`Handling SAAS subscription lifecycle for org ${customer.orgId}`
|
||||
`Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}`
|
||||
);
|
||||
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses
|
||||
await handleSubscriptionLifesycle(
|
||||
customer.orgId,
|
||||
subscription.status
|
||||
subscription.status,
|
||||
type
|
||||
);
|
||||
|
||||
// Handle initial tier setup - disable features not available in this tier
|
||||
logger.info(
|
||||
`Setting up initial tier features for org ${customer.orgId} with type ${type}`
|
||||
);
|
||||
await handleTierChange(customer.orgId, type);
|
||||
|
||||
const [orgUserRes] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
|
||||
@@ -27,6 +27,7 @@ import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
|
||||
import { getSubType } from "./getSubType";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import { handleTierChange } from "../featureLifecycle";
|
||||
|
||||
export async function handleSubscriptionDeleted(
|
||||
subscription: Stripe.Subscription
|
||||
@@ -76,16 +77,23 @@ export async function handleSubscriptionDeleted(
|
||||
}
|
||||
|
||||
const type = getSubType(fullSubscription);
|
||||
if (type === "saas") {
|
||||
if (type == "tier1" || type == "tier2" || type == "tier3") {
|
||||
logger.debug(
|
||||
`Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}`
|
||||
);
|
||||
|
||||
await handleSubscriptionLifesycle(
|
||||
customer.orgId,
|
||||
subscription.status
|
||||
subscription.status,
|
||||
type
|
||||
);
|
||||
|
||||
// Handle feature lifecycle for cancellation - disable all tier-specific features
|
||||
logger.info(
|
||||
`Disabling tier-specific features for org ${customer.orgId} due to subscription deletion`
|
||||
);
|
||||
await handleTierChange(customer.orgId, null, type);
|
||||
|
||||
const [orgUserRes] = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
|
||||
@@ -23,11 +23,12 @@ import {
|
||||
} from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { getFeatureIdByMetricId } from "@server/lib/billing/features";
|
||||
import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
import { getSubType } from "./getSubType";
|
||||
import { getSubType, SubscriptionType } from "./getSubType";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import { handleTierChange } from "../featureLifecycle";
|
||||
|
||||
export async function handleSubscriptionUpdated(
|
||||
subscription: Stripe.Subscription,
|
||||
@@ -64,6 +65,9 @@ export async function handleSubscriptionUpdated(
|
||||
.where(eq(customers.customerId, subscription.customer as string))
|
||||
.limit(1);
|
||||
|
||||
const type = getSubType(fullSubscription);
|
||||
const previousType = existingSubscription.type as SubscriptionType | null;
|
||||
|
||||
await db
|
||||
.update(subscriptions)
|
||||
.set({
|
||||
@@ -72,25 +76,55 @@ export async function handleSubscriptionUpdated(
|
||||
? subscription.canceled_at
|
||||
: null,
|
||||
updatedAt: Math.floor(Date.now() / 1000),
|
||||
billingCycleAnchor: subscription.billing_cycle_anchor
|
||||
billingCycleAnchor: subscription.billing_cycle_anchor,
|
||||
type: type
|
||||
})
|
||||
.where(eq(subscriptions.subscriptionId, subscription.id));
|
||||
|
||||
// Handle tier change if the subscription type changed
|
||||
if (type && type !== previousType) {
|
||||
logger.info(
|
||||
`Tier change detected for org ${customer.orgId}: ${previousType} -> ${type}`
|
||||
);
|
||||
await handleTierChange(customer.orgId, type, previousType ?? undefined);
|
||||
}
|
||||
|
||||
// Upsert subscription items
|
||||
if (Array.isArray(fullSubscription.items?.data)) {
|
||||
const itemsToUpsert = fullSubscription.items.data.map((item) => ({
|
||||
subscriptionId: subscription.id,
|
||||
planId: item.plan.id,
|
||||
priceId: item.price.id,
|
||||
meterId: item.plan.meter,
|
||||
unitAmount: item.price.unit_amount || 0,
|
||||
currentPeriodStart: item.current_period_start,
|
||||
currentPeriodEnd: item.current_period_end,
|
||||
tiers: item.price.tiers
|
||||
? JSON.stringify(item.price.tiers)
|
||||
: null,
|
||||
interval: item.plan.interval
|
||||
}));
|
||||
// First, get existing items to preserve featureId when there's no match
|
||||
const existingItems = await db
|
||||
.select()
|
||||
.from(subscriptionItems)
|
||||
.where(eq(subscriptionItems.subscriptionId, subscription.id));
|
||||
|
||||
const itemsToUpsert = fullSubscription.items.data.map((item) => {
|
||||
// Try to get featureId from price
|
||||
let featureId: string | null = getFeatureIdByPriceId(item.price.id) || null;
|
||||
|
||||
// If no match, try to preserve existing featureId
|
||||
if (!featureId) {
|
||||
const existingItem = existingItems.find(
|
||||
(ei) => ei.stripeSubscriptionItemId === item.id
|
||||
);
|
||||
featureId = existingItem?.featureId || null;
|
||||
}
|
||||
|
||||
return {
|
||||
stripeSubscriptionItemId: item.id,
|
||||
subscriptionId: subscription.id,
|
||||
planId: item.plan.id,
|
||||
priceId: item.price.id,
|
||||
featureId: featureId,
|
||||
meterId: item.plan.meter,
|
||||
unitAmount: item.price.unit_amount || 0,
|
||||
currentPeriodStart: item.current_period_start,
|
||||
currentPeriodEnd: item.current_period_end,
|
||||
tiers: item.price.tiers
|
||||
? JSON.stringify(item.price.tiers)
|
||||
: null,
|
||||
interval: item.plan.interval
|
||||
};
|
||||
});
|
||||
if (itemsToUpsert.length > 0) {
|
||||
await db.transaction(async (trx) => {
|
||||
await trx
|
||||
@@ -154,7 +188,7 @@ export async function handleSubscriptionUpdated(
|
||||
const orgId = customer.orgId;
|
||||
|
||||
if (!orgId) {
|
||||
logger.warn(
|
||||
logger.debug(
|
||||
`No orgId found in subscription metadata for subscription ${subscription.id}. Skipping usage reset.`
|
||||
);
|
||||
continue;
|
||||
@@ -234,17 +268,29 @@ export async function handleSubscriptionUpdated(
|
||||
}
|
||||
// --- end usage update ---
|
||||
|
||||
const type = getSubType(fullSubscription);
|
||||
if (type === "saas") {
|
||||
if (type === "tier1" || type === "tier2" || type === "tier3") {
|
||||
logger.debug(
|
||||
`Handling SAAS subscription lifecycle for org ${customer.orgId}`
|
||||
`Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}`
|
||||
);
|
||||
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses
|
||||
await handleSubscriptionLifesycle(
|
||||
customer.orgId,
|
||||
subscription.status
|
||||
subscription.status,
|
||||
type
|
||||
);
|
||||
} else {
|
||||
|
||||
// Handle feature lifecycle when subscription is canceled or becomes unpaid
|
||||
if (
|
||||
subscription.status === "canceled" ||
|
||||
subscription.status === "unpaid" ||
|
||||
subscription.status === "incomplete_expired"
|
||||
) {
|
||||
logger.info(
|
||||
`Subscription ${subscription.id} for org ${customer.orgId} is ${subscription.status}, disabling paid features`
|
||||
);
|
||||
await handleTierChange(customer.orgId, null, previousType ?? undefined);
|
||||
}
|
||||
} else if (type === "license") {
|
||||
if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") {
|
||||
try {
|
||||
// WARNING:
|
||||
|
||||
@@ -11,8 +11,9 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./createCheckoutSessionSAAS";
|
||||
export * from "./createCheckoutSession";
|
||||
export * from "./createPortalSession";
|
||||
export * from "./getOrgSubscriptions";
|
||||
export * from "./getOrgUsage";
|
||||
export * from "./internalGetOrgTier";
|
||||
export * from "./changeTier";
|
||||
|
||||
@@ -13,38 +13,66 @@
|
||||
|
||||
import {
|
||||
freeLimitSet,
|
||||
tier1LimitSet,
|
||||
tier2LimitSet,
|
||||
tier3LimitSet,
|
||||
limitsService,
|
||||
subscribedLimitSet
|
||||
LimitSet
|
||||
} from "@server/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import logger from "@server/logger";
|
||||
import { SubscriptionType } from "./hooks/getSubType";
|
||||
|
||||
function getLimitSetForSubscriptionType(
|
||||
subType: SubscriptionType | null
|
||||
): LimitSet {
|
||||
switch (subType) {
|
||||
case "tier1":
|
||||
return tier1LimitSet;
|
||||
case "tier2":
|
||||
return tier2LimitSet;
|
||||
case "tier3":
|
||||
return tier3LimitSet;
|
||||
case "license":
|
||||
// License subscriptions use tier2 limits by default
|
||||
// This can be adjusted based on your business logic
|
||||
return tier2LimitSet;
|
||||
default:
|
||||
return freeLimitSet;
|
||||
}
|
||||
}
|
||||
|
||||
export async function handleSubscriptionLifesycle(
|
||||
orgId: string,
|
||||
status: string
|
||||
status: string,
|
||||
subType: SubscriptionType | null
|
||||
) {
|
||||
switch (status) {
|
||||
case "active":
|
||||
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
const activeLimitSet = getLimitSetForSubscriptionType(subType);
|
||||
await limitsService.applyLimitSetToOrg(orgId, activeLimitSet);
|
||||
await usageService.checkLimitSet(orgId);
|
||||
break;
|
||||
case "canceled":
|
||||
// Subscription canceled - revert to free tier
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
await usageService.checkLimitSet(orgId);
|
||||
break;
|
||||
case "past_due":
|
||||
// Optionally handle past due status, e.g., notify customer
|
||||
// Payment past due - keep current limits but notify customer
|
||||
// Limits will revert to free tier if it becomes unpaid
|
||||
break;
|
||||
case "unpaid":
|
||||
// Subscription unpaid - revert to free tier
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
await usageService.checkLimitSet(orgId);
|
||||
break;
|
||||
case "incomplete":
|
||||
// Optionally handle incomplete status, e.g., notify customer
|
||||
// Payment incomplete - give them time to complete payment
|
||||
break;
|
||||
case "incomplete_expired":
|
||||
// Payment never completed - revert to free tier
|
||||
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
|
||||
await usageService.checkLimitSet(orgId, true);
|
||||
await usageService.checkLimitSet(orgId);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
||||
@@ -31,7 +31,8 @@ import {
|
||||
verifyUserHasAction,
|
||||
verifyUserIsServerAdmin,
|
||||
verifySiteAccess,
|
||||
verifyClientAccess
|
||||
verifyClientAccess,
|
||||
verifyLimits
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import {
|
||||
@@ -52,6 +53,7 @@ import {
|
||||
authenticated as a,
|
||||
authRouter as aa
|
||||
} from "@server/routers/external";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
|
||||
export const authenticated = a;
|
||||
export const unauthenticated = ua;
|
||||
@@ -76,7 +78,9 @@ unauthenticated.post(
|
||||
authenticated.put(
|
||||
"/org/:orgId/idp/oidc",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.orgOidc),
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createIdp),
|
||||
logActionAudit(ActionsEnum.createIdp),
|
||||
orgIdp.createOrgOidcIdp
|
||||
@@ -85,8 +89,10 @@ authenticated.put(
|
||||
authenticated.post(
|
||||
"/org/:orgId/idp/:idpId/oidc",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.orgOidc),
|
||||
verifyOrgAccess,
|
||||
verifyIdpAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateIdp),
|
||||
logActionAudit(ActionsEnum.updateIdp),
|
||||
orgIdp.updateOrgOidcIdp
|
||||
@@ -135,35 +141,27 @@ authenticated.post(
|
||||
verifyValidLicense,
|
||||
verifyOrgAccess,
|
||||
verifyCertificateAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.restartCertificate),
|
||||
logActionAudit(ActionsEnum.restartCertificate),
|
||||
certificates.restartCertificate
|
||||
);
|
||||
|
||||
if (build === "saas") {
|
||||
unauthenticated.post(
|
||||
"/quick-start",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000,
|
||||
max: 100,
|
||||
keyGenerator: (req) => req.path,
|
||||
handler: (req, res, next) => {
|
||||
const message = `We're too busy right now. Please try again later.`;
|
||||
return next(
|
||||
createHttpError(HttpCode.TOO_MANY_REQUESTS, message)
|
||||
);
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.quickStart
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/billing/create-checkout-session-saas",
|
||||
"/org/:orgId/billing/create-checkout-session",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.billing),
|
||||
logActionAudit(ActionsEnum.billing),
|
||||
billing.createCheckoutSessionSAAS
|
||||
billing.createCheckoutSession
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/billing/change-tier",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.billing),
|
||||
logActionAudit(ActionsEnum.billing),
|
||||
billing.changeTier
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
@@ -243,6 +241,7 @@ authenticated.put(
|
||||
"/org/:orgId/remote-exit-node",
|
||||
verifyValidLicense,
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createRemoteExitNode),
|
||||
logActionAudit(ActionsEnum.createRemoteExitNode),
|
||||
remoteExitNode.createRemoteExitNode
|
||||
@@ -286,7 +285,9 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/org/:orgId/login-page",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.loginPageDomain),
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createLoginPage),
|
||||
logActionAudit(ActionsEnum.createLoginPage),
|
||||
loginPage.createLoginPage
|
||||
@@ -295,8 +296,10 @@ authenticated.put(
|
||||
authenticated.post(
|
||||
"/org/:orgId/login-page/:loginPageId",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.loginPageDomain),
|
||||
verifyOrgAccess,
|
||||
verifyLoginPageAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateLoginPage),
|
||||
logActionAudit(ActionsEnum.updateLoginPage),
|
||||
loginPage.updateLoginPage
|
||||
@@ -323,6 +326,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/approvals",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.deviceApprovals),
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.listApprovals),
|
||||
logActionAudit(ActionsEnum.listApprovals),
|
||||
@@ -339,7 +343,9 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/approvals/:approvalId",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.deviceApprovals),
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateApprovals),
|
||||
logActionAudit(ActionsEnum.updateApprovals),
|
||||
approval.processPendingApproval
|
||||
@@ -348,6 +354,7 @@ authenticated.put(
|
||||
authenticated.get(
|
||||
"/org/:orgId/login-page-branding",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.loginPageBranding),
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.getLoginPage),
|
||||
logActionAudit(ActionsEnum.getLoginPage),
|
||||
@@ -357,7 +364,9 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/login-page-branding",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.loginPageBranding),
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateLoginPage),
|
||||
logActionAudit(ActionsEnum.updateLoginPage),
|
||||
loginPage.upsertLoginPageBranding
|
||||
@@ -433,7 +442,7 @@ authenticated.post(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/action",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.actionLogs),
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.exportLogs),
|
||||
logs.queryActionAuditLogs
|
||||
@@ -442,7 +451,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/action/export",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.logExport),
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
@@ -452,7 +461,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/access",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.accessLogs),
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.exportLogs),
|
||||
logs.queryAccessAuditLogs
|
||||
@@ -461,7 +470,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/access/export",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.logExport),
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
@@ -470,18 +479,20 @@ authenticated.get(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:clientId/regenerate-client-secret",
|
||||
verifyClientAccess, // this is first to set the org id
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.rotateCredentials),
|
||||
verifyClientAccess, // this is first to set the org id
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateClientSecret
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:siteId/regenerate-site-secret",
|
||||
verifySiteAccess, // this is first to set the org id
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.rotateCredentials),
|
||||
verifySiteAccess, // this is first to set the org id
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateSiteSecret
|
||||
);
|
||||
@@ -489,8 +500,9 @@ authenticated.post(
|
||||
authenticated.put(
|
||||
"/re-key/:orgId/regenerate-remote-exit-node-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.rotateCredentials),
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateExitNodeSecret
|
||||
);
|
||||
|
||||
@@ -134,6 +134,7 @@ export async function generateNewEnterpriseLicense(
|
||||
], // Start with the standard feature set that matches the free limits
|
||||
customer: customer.customerId,
|
||||
mode: "subscription",
|
||||
allow_promotion_codes: true,
|
||||
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?success=true&session_id={CHECKOUT_SESSION_ID}`,
|
||||
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?canceled=true`
|
||||
});
|
||||
|
||||
@@ -19,21 +19,20 @@ import {
|
||||
verifyApiKeyHasAction,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyIdpAccess
|
||||
verifyApiKeyIdpAccess,
|
||||
verifyLimits
|
||||
} from "@server/middlewares";
|
||||
import {
|
||||
verifyValidSubscription,
|
||||
verifyValidLicense
|
||||
} from "#private/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
|
||||
import {
|
||||
unauthenticated as ua,
|
||||
authenticated as a
|
||||
} from "@server/routers/integration";
|
||||
import { logActionAudit } from "#private/middlewares";
|
||||
import config from "#private/lib/config";
|
||||
import { build } from "@server/build";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
|
||||
export const unauthenticated = ua;
|
||||
export const authenticated = a;
|
||||
@@ -57,7 +56,7 @@ authenticated.delete(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/action",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.actionLogs),
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logs.queryActionAuditLogs
|
||||
@@ -66,7 +65,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/action/export",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.logExport),
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
@@ -76,7 +75,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/access",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.accessLogs),
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logs.queryAccessAuditLogs
|
||||
@@ -85,7 +84,7 @@ authenticated.get(
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/access/export",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyValidSubscription(tierMatrix.logExport),
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.exportLogs),
|
||||
logActionAudit(ActionsEnum.exportLogs),
|
||||
@@ -95,7 +94,9 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/idp/oidc",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.orgOidc),
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createIdp),
|
||||
logActionAudit(ActionsEnum.createIdp),
|
||||
orgIdp.createOrgOidcIdp
|
||||
@@ -104,8 +105,10 @@ authenticated.put(
|
||||
authenticated.post(
|
||||
"/org/:orgId/idp/:idpId/oidc",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.orgOidc),
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyIdpAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateIdp),
|
||||
logActionAudit(ActionsEnum.updateIdp),
|
||||
orgIdp.updateOrgOidcIdp
|
||||
|
||||
@@ -30,9 +30,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { validateAndConstructDomain } from "@server/lib/domainUtils";
|
||||
import { createCertificate } from "#private/routers/certificates/createCertificate";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
|
||||
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
@@ -76,19 +74,6 @@ export async function createLoginPage(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(loginPageOrg)
|
||||
|
||||
@@ -25,9 +25,7 @@ import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
@@ -53,18 +51,6 @@ export async function deleteLoginPageBranding(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const [existingLoginPageBranding] = await db
|
||||
.select()
|
||||
|
||||
@@ -25,9 +25,7 @@ import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
@@ -51,19 +49,6 @@ export async function getLoginPageBranding(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const [existingLoginPageBranding] = await db
|
||||
.select()
|
||||
.from(loginPageBranding)
|
||||
|
||||
@@ -23,9 +23,7 @@ import { eq, and } from "drizzle-orm";
|
||||
import { validateAndConstructDomain } from "@server/lib/domainUtils";
|
||||
import { subdomainSchema } from "@server/lib/schemas";
|
||||
import { createCertificate } from "#private/routers/certificates/createCertificate";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
|
||||
import { UpdateLoginPageResponse } from "@server/routers/loginPage/types";
|
||||
|
||||
const paramsSchema = z
|
||||
@@ -87,18 +85,6 @@ export async function updateLoginPage(
|
||||
|
||||
const { loginPageId, orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const [existingLoginPage] = await db
|
||||
.select()
|
||||
|
||||
@@ -25,10 +25,8 @@ import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { eq, InferInsertModel } from "drizzle-orm";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { build } from "@server/build";
|
||||
import config from "@server/private/lib/config";
|
||||
import config from "#private/lib/config";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
@@ -128,19 +126,6 @@ export async function upsertLoginPageBranding(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let updateData = parsedBody.data satisfies InferInsertModel<
|
||||
typeof loginPageBranding
|
||||
>;
|
||||
|
||||
@@ -24,10 +24,10 @@ import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db";
|
||||
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
|
||||
import { isSubscribed } from "#private/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import privateConfig from "#private/lib/config";
|
||||
|
||||
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
|
||||
|
||||
@@ -93,6 +93,18 @@ export async function createOrgOidcIdp(
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
privateConfig.getRawPrivateConfig().app.identity_provider_mode !==
|
||||
"org"
|
||||
) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Organization-specific IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'org' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const {
|
||||
clientId,
|
||||
clientSecret,
|
||||
@@ -103,23 +115,19 @@ export async function createOrgOidcIdp(
|
||||
emailPath,
|
||||
namePath,
|
||||
name,
|
||||
autoProvision,
|
||||
variant,
|
||||
roleMapping,
|
||||
tags
|
||||
} = parsedBody.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
let { autoProvision } = parsedBody.data;
|
||||
|
||||
const subscribed = await isSubscribed(
|
||||
orgId,
|
||||
tierMatrix.deviceApprovals
|
||||
);
|
||||
if (!subscribed) {
|
||||
autoProvision = false;
|
||||
}
|
||||
|
||||
const key = config.getRawConfig().server.secret!;
|
||||
|
||||
@@ -22,6 +22,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { idp, idpOidcConfig, idpOrg } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import privateConfig from "#private/lib/config";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
@@ -59,6 +60,18 @@ export async function deleteOrgIdp(
|
||||
|
||||
const { idpId } = parsedParams.data;
|
||||
|
||||
if (
|
||||
privateConfig.getRawPrivateConfig().app.identity_provider_mode !==
|
||||
"org"
|
||||
) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Organization-specific IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'org' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if IDP exists
|
||||
const [existingIdp] = await db
|
||||
.select()
|
||||
|
||||
@@ -24,9 +24,9 @@ import { idp, idpOidcConfig } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { encrypt } from "@server/lib/crypto";
|
||||
import config from "@server/lib/config";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#private/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { isSubscribed } from "#private/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import privateConfig from "#private/lib/config";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
@@ -98,6 +98,18 @@ export async function updateOrgOidcIdp(
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
privateConfig.getRawPrivateConfig().app.identity_provider_mode !==
|
||||
"org"
|
||||
) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Organization-specific IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'org' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
const {
|
||||
clientId,
|
||||
@@ -109,22 +121,18 @@ export async function updateOrgOidcIdp(
|
||||
emailPath,
|
||||
namePath,
|
||||
name,
|
||||
autoProvision,
|
||||
roleMapping,
|
||||
tags
|
||||
} = parsedBody.data;
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier, active } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
let { autoProvision } = parsedBody.data;
|
||||
|
||||
const subscribed = await isSubscribed(
|
||||
orgId,
|
||||
tierMatrix.deviceApprovals
|
||||
);
|
||||
if (!subscribed) {
|
||||
autoProvision = false;
|
||||
}
|
||||
|
||||
// Check if IDP exists and is of type OIDC
|
||||
|
||||
@@ -85,7 +85,7 @@ export async function createRemoteExitNode(
|
||||
if (usage) {
|
||||
const rejectRemoteExitNodes = await usageService.checkLimitSet(
|
||||
orgId,
|
||||
false,
|
||||
|
||||
FeatureId.REMOTE_EXIT_NODES,
|
||||
{
|
||||
...usage,
|
||||
@@ -97,7 +97,7 @@ export async function createRemoteExitNode(
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Remote exit node limit exceeded. Please upgrade your plan or contact us at support@pangolin.net"
|
||||
"Remote node limit exceeded. Please upgrade your plan."
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -224,7 +224,7 @@ export async function createRemoteExitNode(
|
||||
});
|
||||
|
||||
if (numExitNodeOrgs) {
|
||||
await usageService.updateDaily(
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES,
|
||||
numExitNodeOrgs.length
|
||||
|
||||
@@ -106,7 +106,7 @@ export async function deleteRemoteExitNode(
|
||||
});
|
||||
|
||||
if (numExitNodeOrgs) {
|
||||
await usageService.updateDaily(
|
||||
await usageService.updateCount(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES,
|
||||
numExitNodeOrgs.length
|
||||
|
||||
Reference in New Issue
Block a user