Merge branch 'dev' into refactor/paginated-tables

This commit is contained in:
Fred KISSIE
2026-02-13 06:03:09 +01:00
184 changed files with 6127 additions and 4176 deletions

View File

@@ -11,46 +11,59 @@
* This file is not licensed under the AGPLv3.
*/
import { getTierPriceSet } from "@server/lib/billing/tiers";
import { getOrgSubscriptionsData } from "@server/private/routers/billing/getOrgSubscriptions";
import { build } from "@server/build";
import { db, customers, subscriptions } from "@server/db";
import { Tier } from "@server/types/Tiers";
import { eq, and, ne } from "drizzle-orm";
export async function getOrgTierData(
orgId: string
): Promise<{ tier: string | null; active: boolean }> {
let tier = null;
): Promise<{ tier: Tier | null; active: boolean }> {
let tier: Tier | null = null;
let active = false;
if (build !== "saas") {
return { tier, active };
}
// TODO: THIS IS INEFFICIENT!!! WE SHOULD IMPROVE HOW WE STORE TIERS WITH SUBSCRIPTIONS AND RETRIEVE THEM
try {
// Get customer for org
const [customer] = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.limit(1);
const subscriptionsWithItems = await getOrgSubscriptionsData(orgId);
if (customer) {
// Query for active subscriptions that are not license type
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
ne(subscriptions.type, "license")
)
)
.limit(1);
for (const { subscription, items } of subscriptionsWithItems) {
if (items && items.length > 0) {
const tierPriceSet = getTierPriceSet();
// Iterate through tiers in order (earlier keys are higher tiers)
for (const [tierId, priceId] of Object.entries(tierPriceSet)) {
// Check if any subscription item matches this tier's price ID
const matchingItem = items.find((item) => item.priceId === priceId);
if (matchingItem) {
tier = tierId;
break;
if (subscription) {
// Validate that subscription.type is one of the expected tier values
if (
subscription.type === "tier1" ||
subscription.type === "tier2" ||
subscription.type === "tier3"
) {
tier = subscription.type;
active = true;
}
}
}
if (subscription && subscription.status === "active") {
active = true;
}
// If we found a tier and active subscription, we can stop
if (tier && active) {
break;
}
} catch (error) {
// If org not found or error occurs, return null tier and inactive
// This is acceptable behavior as per the function signature
}
return { tier, active };
}

View File

@@ -13,8 +13,6 @@
import { build } from "@server/build";
import { db, Org, orgs, ResourceSession, sessions, users } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import license from "#private/license/license";
import { eq } from "drizzle-orm";
import {
@@ -80,6 +78,8 @@ export async function checkOrgAccessPolicy(
}
}
// TODO: check that the org is subscribed
// get the needed data
if (!props.org) {

View File

@@ -65,6 +65,11 @@ export class PrivateConfig {
this.rawPrivateConfig.branding?.logo?.dark_path || undefined;
}
if (this.rawPrivateConfig.app.identity_provider_mode) {
process.env.IDENTITY_PROVIDER_MODE =
this.rawPrivateConfig.app.identity_provider_mode;
}
process.env.BRANDING_LOGO_AUTH_WIDTH = this.rawPrivateConfig.branding
?.logo?.auth_page?.width
? this.rawPrivateConfig.branding?.logo?.auth_page?.width.toString()
@@ -125,24 +130,10 @@ export class PrivateConfig {
this.rawPrivateConfig.server.reo_client_id;
}
if (this.rawPrivateConfig.stripe?.s3Bucket) {
process.env.S3_BUCKET = this.rawPrivateConfig.stripe.s3Bucket;
}
if (this.rawPrivateConfig.stripe?.localFilePath) {
process.env.LOCAL_FILE_PATH =
this.rawPrivateConfig.stripe.localFilePath;
}
if (this.rawPrivateConfig.stripe?.s3Region) {
process.env.S3_REGION = this.rawPrivateConfig.stripe.s3Region;
}
if (this.rawPrivateConfig.flags.use_pangolin_dns) {
process.env.USE_PANGOLIN_DNS =
this.rawPrivateConfig.flags.use_pangolin_dns.toString();
}
if (this.rawPrivateConfig.flags.use_org_only_idp) {
process.env.USE_ORG_ONLY_IDP =
this.rawPrivateConfig.flags.use_org_only_idp.toString();
}
}
public getRawPrivateConfig() {

View File

@@ -13,18 +13,20 @@
import { build } from "@server/build";
import license from "#private/license/license";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { isSubscribed } from "#private/lib/isSubscribed";
import { Tier } from "@server/types/Tiers";
export async function isLicensedOrSubscribed(orgId: string): Promise<boolean> {
export async function isLicensedOrSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
if (build === "enterprise") {
return await license.isUnlocked();
}
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
return tier === TierId.STANDARD;
return isSubscribed(orgId, tiers);
}
return false;
}
}

View File

@@ -0,0 +1,29 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { Tier } from "@server/types/Tiers";
export async function isSubscribed(
orgId: string,
tiers: Tier[]
): Promise<boolean> {
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const isTier = (tier && tiers.includes(tier)) || false;
return active && isTier;
}
return false;
}

View File

@@ -25,7 +25,8 @@ export const privateConfigSchema = z.object({
app: z
.object({
region: z.string().optional().default("default"),
base_domain: z.string().optional()
base_domain: z.string().optional(),
identity_provider_mode: z.enum(["global", "org"]).optional()
})
.optional()
.default({
@@ -95,7 +96,7 @@ export const privateConfigSchema = z.object({
.object({
enable_redis: z.boolean().optional().default(false),
use_pangolin_dns: z.boolean().optional().default(false),
use_org_only_idp: z.boolean().optional().default(false)
use_org_only_idp: z.boolean().optional()
})
.optional()
.prefault({}),
@@ -176,12 +177,34 @@ export const privateConfigSchema = z.object({
.string()
.optional()
.transform(getEnvOrYaml("STRIPE_WEBHOOK_SECRET")),
s3Bucket: z.string(),
s3Region: z.string().default("us-east-1"),
localFilePath: z.string()
// s3Bucket: z.string(),
// s3Region: z.string().default("us-east-1"),
// localFilePath: z.string().optional()
})
.optional()
});
})
.transform((data) => {
// this to maintain backwards compatibility with the old config file
const identityProviderMode = data.app?.identity_provider_mode;
const useOrgOnlyIdp = data.flags?.use_org_only_idp;
if (identityProviderMode !== undefined) {
return data;
}
if (useOrgOnlyIdp === true) {
return {
...data,
app: { ...data.app, identity_provider_mode: "org" as const }
};
}
if (useOrgOnlyIdp === false) {
return {
...data,
app: { ...data.app, identity_provider_mode: "global" as const }
};
}
return data;
});
export function readPrivateConfigFile() {
if (build == "oss") {

View File

@@ -16,46 +16,61 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { Tier } from "@server/types/Tiers";
export function verifyValidSubscription(tiers: Tier[]) {
return async function (
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
if (build != "saas") {
return next();
}
const orgId =
req.params.orgId ||
req.body.orgId ||
req.query.orgId ||
req.userOrgId;
if (!orgId) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization ID is required to verify subscription"
)
);
}
const { tier, active } = await getOrgTierData(orgId);
const isTier = tiers.includes(tier as Tier);
if (!active) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization does not have an active subscription"
)
);
}
if (!isTier) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization subscription tier does not have access to this feature"
)
);
}
export async function verifyValidSubscription(
req: Request,
res: Response,
next: NextFunction
) {
try {
if (build != "saas") {
return next();
}
const orgId = req.params.orgId || req.body.orgId || req.query.orgId || req.userOrgId;
if (!orgId) {
} catch (e) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization ID is required to verify subscription"
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying subscription"
)
);
}
const tier = await getOrgTierData(orgId);
if (!tier.active) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization does not have an active subscription"
)
);
}
return next();
} catch (e) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Error verifying subscription"
)
);
}
};
}

View File

@@ -19,8 +19,6 @@ import { fromError } from "zod-validation-error";
import type { Request, Response, NextFunction } from "express";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import {
approvals,
clients,
@@ -273,19 +271,6 @@ export async function listApprovals(
const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const { approvalsList, nextCursorPending, nextCursorTimestamp } =
await queryApprovals({
orgId: orgId.toString(),

View File

@@ -17,10 +17,7 @@ import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
import { build } from "@server/build";
import { approvals, clients, db, orgs, type Approval } from "@server/db";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import response from "@server/lib/response";
import { and, eq, type InferInsertModel } from "drizzle-orm";
import type { NextFunction, Request, Response } from "express";
@@ -64,20 +61,6 @@ export async function processPendingApproval(
}
const { orgId, approvalId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const updateData = parsedBody.data;
const approval = await db

View File

@@ -13,4 +13,3 @@
export * from "./transferSession";
export * from "./getSessionTransferToken";
export * from "./quickStart";

View File

@@ -1,585 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { NextFunction, Request, Response } from "express";
import {
account,
db,
domainNamespaces,
domains,
exitNodes,
newts,
newtSessions,
orgs,
passwordResetTokens,
Resource,
resourcePassword,
resourcePincode,
resources,
resourceWhitelist,
roleResources,
roles,
roleSites,
sites,
targetHealthCheck,
targets,
userResources,
userSites
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { users } from "@server/db";
import { fromError } from "zod-validation-error";
import createHttpError from "http-errors";
import response from "@server/lib/response";
import { SqliteError } from "better-sqlite3";
import { eq, and, sql } from "drizzle-orm";
import moment from "moment";
import { generateId } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import logger from "@server/logger";
import { hashPassword } from "@server/auth/password";
import { UserType } from "@server/types/UserTypes";
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
import { sendEmail } from "@server/emails";
import WelcomeQuickStart from "@server/emails/templates/WelcomeQuickStart";
import { alphabet, generateRandomString } from "oslo/crypto";
import { createDate, TimeSpan } from "oslo";
import { getUniqueResourceName, getUniqueSiteName } from "@server/db/names";
import { pickPort } from "@server/routers/target/helpers";
import { addTargets } from "@server/routers/newt/targets";
import { isTargetValid } from "@server/lib/validators";
import { listExitNodes } from "#private/lib/exitNodes";
const bodySchema = z.object({
email: z.email().toLowerCase(),
ip: z.string().refine(isTargetValid),
method: z.enum(["http", "https"]),
port: z.int().min(1).max(65535),
pincode: z
.string()
.regex(/^\d{6}$/)
.optional(),
password: z.string().min(4).max(100).optional(),
enableWhitelist: z.boolean().optional().default(true),
animalId: z.string() // This is actually the secret key for the backend
});
export type QuickStartBody = z.infer<typeof bodySchema>;
export type QuickStartResponse = {
newtId: string;
newtSecret: string;
resourceUrl: string;
completeSignUpLink: string;
};
const DEMO_UBO_KEY = "b460293f-347c-4b30-837d-4e06a04d5a22";
export async function quickStart(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const {
email,
ip,
method,
port,
pincode,
password,
enableWhitelist,
animalId
} = parsedBody.data;
try {
const tokenValidation = validateTokenOnApi(animalId);
if (!tokenValidation.isValid) {
logger.warn(
`Quick start failed for ${email} token ${animalId}: ${tokenValidation.message}`
);
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid or expired token"
)
);
}
if (animalId === DEMO_UBO_KEY) {
if (email !== "mehrdad@getubo.com") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid email for demo Ubo key"
)
);
}
const [existing] = await db
.select()
.from(users)
.where(
and(
eq(users.email, email),
eq(users.type, UserType.Internal)
)
);
if (existing) {
// delete the user if it already exists
await db.delete(users).where(eq(users.userId, existing.userId));
const orgId = `org_${existing.userId}`;
await db.delete(orgs).where(eq(orgs.orgId, orgId));
}
}
const tempPassword = generateId(15);
const passwordHash = await hashPassword(tempPassword);
const userId = generateId(15);
// TODO: see if that user already exists?
// Create the sandbox user
const existing = await db
.select()
.from(users)
.where(
and(eq(users.email, email), eq(users.type, UserType.Internal))
);
if (existing && existing.length > 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A user with that email address already exists"
)
);
}
let newtId: string;
let secret: string;
let fullDomain: string;
let resource: Resource;
let completeSignUpLink: string;
await db.transaction(async (trx) => {
await trx.insert(users).values({
userId: userId,
type: UserType.Internal,
username: email,
email: email,
passwordHash,
dateCreated: moment().toISOString()
});
// create user"s account
await trx.insert(account).values({
userId
});
});
const { success, error, org } = await createUserAccountOrg(
userId,
email
);
if (!success) {
if (error) {
throw new Error(error);
}
throw new Error("Failed to create user account and organization");
}
if (!org) {
throw new Error("Failed to create user account and organization");
}
const orgId = org.orgId;
await db.transaction(async (trx) => {
const token = generateRandomString(
8,
alphabet("0-9", "A-Z", "a-z")
);
await trx
.delete(passwordResetTokens)
.where(eq(passwordResetTokens.userId, userId));
const tokenHash = await hashPassword(token);
await trx.insert(passwordResetTokens).values({
userId: userId,
email: email,
tokenHash,
expiresAt: createDate(new TimeSpan(7, "d")).getTime()
});
// // Create the sandbox newt
// const newClientAddress = await getNextAvailableClientSubnet(orgId);
// if (!newClientAddress) {
// throw new Error("No available subnet found");
// }
// const clientAddress = newClientAddress.split("/")[0];
newtId = generateId(15);
secret = generateId(48);
// Create the sandbox site
const siteNiceId = await getUniqueSiteName(orgId);
const siteName = `First Site`;
// pick a random exit node
const exitNodesList = await listExitNodes(orgId);
// select a random exit node
const randomExitNode =
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
if (!randomExitNode) {
throw new Error("No exit nodes available");
}
const [newSite] = await trx
.insert(sites)
.values({
orgId,
exitNodeId: randomExitNode.exitNodeId,
name: siteName,
niceId: siteNiceId,
// address: clientAddress,
type: "newt",
dockerSocketEnabled: true
})
.returning();
const siteId = newSite.siteId;
const adminRole = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
throw new Error("Admin role not found");
}
await trx.insert(roleSites).values({
roleId: adminRole[0].roleId,
siteId: newSite.siteId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the site
await trx.insert(userSites).values({
userId: req.user?.userId!,
siteId: newSite.siteId
});
}
// add the peer to the exit node
const secretHash = await hashPassword(secret!);
await trx.insert(newts).values({
newtId: newtId!,
secretHash,
siteId: newSite.siteId,
dateCreated: moment().toISOString()
});
const [randomNamespace] = await trx
.select()
.from(domainNamespaces)
.orderBy(sql`RANDOM()`)
.limit(1);
if (!randomNamespace) {
throw new Error("No domain namespace available");
}
const [randomNamespaceDomain] = await trx
.select()
.from(domains)
.where(eq(domains.domainId, randomNamespace.domainId))
.limit(1);
if (!randomNamespaceDomain) {
throw new Error("No domain found for the namespace");
}
const resourceNiceId = await getUniqueResourceName(orgId);
// Create sandbox resource
const subdomain = `${resourceNiceId}-${generateId(5)}`;
fullDomain = `${subdomain}.${randomNamespaceDomain.baseDomain}`;
const resourceName = `First Resource`;
const newResource = await trx
.insert(resources)
.values({
niceId: resourceNiceId,
fullDomain,
domainId: randomNamespaceDomain.domainId,
orgId,
name: resourceName,
subdomain,
http: true,
protocol: "tcp",
ssl: true,
sso: false,
emailWhitelistEnabled: enableWhitelist
})
.returning();
await trx.insert(roleResources).values({
roleId: adminRole[0].roleId,
resourceId: newResource[0].resourceId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
// make sure the user can access the resource
await trx.insert(userResources).values({
userId: req.user?.userId!,
resourceId: newResource[0].resourceId
});
}
resource = newResource[0];
// Create the sandbox target
const { internalPort, targetIps } = await pickPort(siteId!, trx);
if (!internalPort) {
throw new Error("No available internal port");
}
const newTarget = await trx
.insert(targets)
.values({
resourceId: resource.resourceId,
siteId: siteId!,
internalPort,
ip,
method,
port,
enabled: true
})
.returning();
const newHealthcheck = await trx
.insert(targetHealthCheck)
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {
await trx
.delete(resourcePincode)
.where(
eq(resourcePincode.resourceId, resource!.resourceId)
);
const pincodeHash = await hashPassword(pincode);
await trx.insert(resourcePincode).values({
resourceId: resource!.resourceId,
pincodeHash,
digitLength: 6
});
}
// Set resource password if provided
if (password) {
await trx
.delete(resourcePassword)
.where(
eq(resourcePassword.resourceId, resource!.resourceId)
);
const passwordHash = await hashPassword(password);
await trx.insert(resourcePassword).values({
resourceId: resource!.resourceId,
passwordHash
});
}
// Set resource OTP if whitelist is enabled
if (enableWhitelist) {
await trx.insert(resourceWhitelist).values({
email,
resourceId: resource!.resourceId
});
}
completeSignUpLink = `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}&token=${token}`;
// Store token for email outside transaction
await sendEmail(
WelcomeQuickStart({
username: email,
link: completeSignUpLink,
fallbackLink: `${config.getRawConfig().app.dashboard_url}/auth/reset-password?quickstart=true&email=${email}`,
resourceMethod: method,
resourceHostname: ip,
resourcePort: port,
resourceUrl: `https://${fullDomain}`,
cliCommand: `newt --id ${newtId} --secret ${secret}`
}),
{
to: email,
from: config.getNoReplyEmail(),
subject: `Access your Pangolin dashboard and resources`
}
);
});
return response<QuickStartResponse>(res, {
data: {
newtId: newtId!,
newtSecret: secret!,
resourceUrl: `https://${fullDomain!}`,
completeSignUpLink: completeSignUpLink!
},
success: true,
error: false,
message: "Quick start completed successfully",
status: HttpCode.OK
});
} catch (e) {
if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
`Account already exists with that email. Email: ${email}. IP: ${req.ip}.`
);
}
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"A user with that email address already exists"
)
);
} else {
logger.error(e);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to do quick start"
)
);
}
}
}
const BACKEND_SECRET_KEY = "4f9b6000-5d1a-11f0-9de7-ff2cc032f501";
/**
* Validates a token received from the frontend.
* @param {string} token The validation token from the request.
* @returns {{ isValid: boolean; message: string }} An object indicating if the token is valid.
*/
const validateTokenOnApi = (
token: string
): { isValid: boolean; message: string } => {
if (token === DEMO_UBO_KEY) {
// Special case for demo UBO key
return { isValid: true, message: "Demo UBO key is valid." };
}
if (!token) {
return { isValid: false, message: "Error: No token provided." };
}
try {
// 1. Decode the base64 string
const decodedB64 = atob(token);
// 2. Reverse the character code manipulation
const deobfuscated = decodedB64
.split("")
.map((char) => String.fromCharCode(char.charCodeAt(0) - 5)) // Reverse the shift
.join("");
// 3. Split the data to get the original secret and timestamp
const parts = deobfuscated.split("|");
if (parts.length !== 2) {
throw new Error("Invalid token format.");
}
const receivedKey = parts[0];
const tokenTimestamp = parseInt(parts[1], 10);
// 4. Check if the secret key matches
if (receivedKey !== BACKEND_SECRET_KEY) {
return { isValid: false, message: "Invalid token: Key mismatch." };
}
// 5. Check if the timestamp is recent (e.g., within 30 seconds) to prevent replay attacks
const now = Date.now();
const timeDifference = now - tokenTimestamp;
if (timeDifference > 30000) {
// 30 seconds
return { isValid: false, message: "Invalid token: Expired." };
}
if (timeDifference < 0) {
// Timestamp is in the future
return {
isValid: false,
message: "Invalid token: Timestamp is in the future."
};
}
// If all checks pass, the token is valid
return { isValid: true, message: "Token is valid!" };
} catch (error) {
// This will catch errors from atob (if not valid base64) or other issues.
return {
isValid: false,
message: `Error: ${(error as Error).message}`
};
}
};

View File

@@ -0,0 +1,268 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { customers, db, subscriptions, subscriptionItems } from "@server/db";
import { eq, and, or } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
import {
getTier1FeaturePriceSet,
getTier3FeaturePriceSet,
getTier2FeaturePriceSet,
FeatureId,
type FeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
const changeTierSchema = z.strictObject({
orgId: z.string()
});
const changeTierBodySchema = z.strictObject({
tier: z.enum(["tier1", "tier2", "tier3"])
});
export async function changeTier(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = changeTierSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const parsedBody = changeTierBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { tier } = parsedBody.data;
// Get the customer for this org
const [customer] = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.limit(1);
if (!customer) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No customer found for this organization"
)
);
}
// Get the active subscription for this customer
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
or(
eq(subscriptions.type, "tier1"),
eq(subscriptions.type, "tier2"),
eq(subscriptions.type, "tier3")
)
)
)
.limit(1);
if (!subscription) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No active subscription found for this organization"
)
);
}
// Get the target tier's price set
let targetPriceSet: FeaturePriceSet;
if (tier === "tier1") {
targetPriceSet = getTier1FeaturePriceSet();
} else if (tier === "tier2") {
targetPriceSet = getTier2FeaturePriceSet();
} else if (tier === "tier3") {
targetPriceSet = getTier3FeaturePriceSet();
} else {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid tier"));
}
// Get current subscription items from our database
const currentItems = await db
.select()
.from(subscriptionItems)
.where(
eq(
subscriptionItems.subscriptionId,
subscription.subscriptionId
)
);
if (currentItems.length === 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"No subscription items found"
)
);
}
// Retrieve the full subscription from Stripe to get item IDs
const stripeSubscription = await stripe!.subscriptions.retrieve(
subscription.subscriptionId
);
// Determine if we're switching between different products
// tier1 uses TIER1 product, tier2/tier3 use USERS product
const currentTier = subscription.type;
const switchingProducts =
(currentTier === "tier1" &&
(tier === "tier2" || tier === "tier3")) ||
((currentTier === "tier2" || currentTier === "tier3") &&
tier === "tier1");
let updatedSubscription;
if (switchingProducts) {
// When switching between different products, we need to:
// 1. Delete old subscription items
// 2. Add new subscription items
logger.info(
`Switching products from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}`
);
// Build array to delete all existing items and add new ones
const itemsToUpdate: any[] = [];
// Mark all existing items for deletion
for (const stripeItem of stripeSubscription.items.data) {
itemsToUpdate.push({
id: stripeItem.id,
deleted: true
});
}
// Add new items for the target tier
const newLineItems = await getLineItems(targetPriceSet, orgId);
for (const lineItem of newLineItems) {
itemsToUpdate.push(lineItem);
}
updatedSubscription = await stripe!.subscriptions.update(
subscription.subscriptionId,
{
items: itemsToUpdate,
proration_behavior: "create_prorations"
}
);
} else {
// Same product, different price tier (tier2 <-> tier3)
// We can simply update the price
logger.info(
`Updating price from ${currentTier} to ${tier} for subscription ${subscription.subscriptionId}`
);
const itemsToUpdate = stripeSubscription.items.data.map(
(stripeItem) => {
// Find the corresponding item in our database
const dbItem = currentItems.find(
(item) => item.priceId === stripeItem.price.id
);
if (!dbItem) {
// Keep the existing item unchanged if we can't find it
return {
id: stripeItem.id,
price: stripeItem.price.id,
quantity: stripeItem.quantity
};
}
// Map to the corresponding feature in the new tier
const newPriceId = targetPriceSet[FeatureId.USERS];
if (newPriceId) {
return {
id: stripeItem.id,
price: newPriceId,
quantity: stripeItem.quantity
};
}
// If no mapping found, keep existing
return {
id: stripeItem.id,
price: stripeItem.price.id,
quantity: stripeItem.quantity
};
}
);
updatedSubscription = await stripe!.subscriptions.update(
subscription.subscriptionId,
{
items: itemsToUpdate,
proration_behavior: "create_prorations"
}
);
}
logger.info(
`Successfully changed tier to ${tier} for org ${orgId}, subscription ${subscription.subscriptionId}`
);
return response<{ subscriptionId: string; newTier: string }>(res, {
data: {
subscriptionId: updatedSubscription.id,
newTier: tier
},
success: true,
error: false,
message: "Tier change successful",
status: HttpCode.OK
});
} catch (error) {
logger.error("Error changing tier:", error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred while changing tier"
)
);
}
}

View File

@@ -22,14 +22,23 @@ import logger from "@server/logger";
import config from "@server/lib/config";
import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
import {
getTier1FeaturePriceSet,
getTier3FeaturePriceSet,
getTier2FeaturePriceSet
} from "@server/lib/billing";
import { getLineItems } from "@server/lib/billing/getLineItems";
import Stripe from "stripe";
const createCheckoutSessionSchema = z.strictObject({
orgId: z.string()
});
export async function createCheckoutSessionSAAS(
const createCheckoutSessionBodySchema = z.strictObject({
tier: z.enum(["tier1", "tier2", "tier3"])
});
export async function createCheckoutSession(
req: Request,
res: Response,
next: NextFunction
@@ -47,6 +56,18 @@ export async function createCheckoutSessionSAAS(
const { orgId } = parsedParams.data;
const parsedBody = createCheckoutSessionBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { tier } = parsedBody.data;
// check if we already have a customer for this org
const [customer] = await db
.select()
@@ -65,20 +86,26 @@ export async function createCheckoutSessionSAAS(
);
}
const standardTierPrice = getTierPriceSet()[TierId.STANDARD];
let lineItems: Stripe.Checkout.SessionCreateParams.LineItem[];
if (tier === "tier1") {
lineItems = await getLineItems(getTier1FeaturePriceSet(), orgId);
} else if (tier === "tier2") {
lineItems = await getLineItems(getTier2FeaturePriceSet(), orgId);
} else if (tier === "tier3") {
lineItems = await getLineItems(getTier3FeaturePriceSet(), orgId);
} else {
return next(createHttpError(HttpCode.BAD_REQUEST, "Invalid plan"));
}
logger.debug(`Line items: ${JSON.stringify(lineItems)}`);
const session = await stripe!.checkout.sessions.create({
client_reference_id: orgId, // So we can look it up the org later on the webhook
billing_address_collection: "required",
line_items: [
{
price: standardTierPrice, // Use the standard tier
quantity: 1
},
...getLineItems(getStandardFeaturePriceSet())
], // Start with the standard feature set that matches the free limits
line_items: lineItems,
customer: customer.customerId,
mode: "subscription",
allow_promotion_codes: true,
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?success=true&session_id={CHECKOUT_SESSION_ID}`,
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/billing?canceled=true`
});

View File

@@ -0,0 +1,410 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import { SubscriptionType } from "./hooks/getSubType";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { Tier } from "@server/types/Tiers";
import logger from "@server/logger";
import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db";
import { eq } from "drizzle-orm";
/**
* Get the maximum allowed retention days for a given tier
* Returns null for enterprise tier (unlimited)
*/
function getMaxRetentionDaysForTier(tier: Tier | null): number | null {
if (!tier) {
return 3; // Free tier
}
switch (tier) {
case "tier1":
return 7;
case "tier2":
return 30;
case "tier3":
return 90;
case "enterprise":
return null; // No limit
default:
return 3; // Default to free tier limit
}
}
/**
* Cap retention days to the maximum allowed for the given tier
*/
async function capRetentionDays(
orgId: string,
tier: Tier | null
): Promise<void> {
const maxRetentionDays = getMaxRetentionDaysForTier(tier);
// If there's no limit (enterprise tier), no capping needed
if (maxRetentionDays === null) {
logger.debug(
`No retention day limit for org ${orgId} on tier ${tier || "free"}`
);
return;
}
// Get current org settings
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
if (!org) {
logger.warn(`Org ${orgId} not found when capping retention days`);
return;
}
const updates: Partial<typeof orgs.$inferInsert> = {};
let needsUpdate = false;
// Cap request log retention if it exceeds the limit
if (
org.settingsLogRetentionDaysRequest !== null &&
org.settingsLogRetentionDaysRequest > maxRetentionDays
) {
updates.settingsLogRetentionDaysRequest = maxRetentionDays;
needsUpdate = true;
logger.info(
`Capping request log retention from ${org.settingsLogRetentionDaysRequest} to ${maxRetentionDays} days for org ${orgId}`
);
}
// Cap access log retention if it exceeds the limit
if (
org.settingsLogRetentionDaysAccess !== null &&
org.settingsLogRetentionDaysAccess > maxRetentionDays
) {
updates.settingsLogRetentionDaysAccess = maxRetentionDays;
needsUpdate = true;
logger.info(
`Capping access log retention from ${org.settingsLogRetentionDaysAccess} to ${maxRetentionDays} days for org ${orgId}`
);
}
// Cap action log retention if it exceeds the limit
if (
org.settingsLogRetentionDaysAction !== null &&
org.settingsLogRetentionDaysAction > maxRetentionDays
) {
updates.settingsLogRetentionDaysAction = maxRetentionDays;
needsUpdate = true;
logger.info(
`Capping action log retention from ${org.settingsLogRetentionDaysAction} to ${maxRetentionDays} days for org ${orgId}`
);
}
// Apply updates if needed
if (needsUpdate) {
await db
.update(orgs)
.set(updates)
.where(eq(orgs.orgId, orgId));
logger.info(
`Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days`
);
} else {
logger.debug(
`No retention day capping needed for org ${orgId}`
);
}
}
export async function handleTierChange(
orgId: string,
newTier: SubscriptionType | null,
previousTier?: SubscriptionType | null
): Promise<void> {
logger.info(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// License subscriptions are handled separately and don't use the tier matrix
if (newTier === "license") {
logger.debug(
`New tier is license for org ${orgId}, no feature lifecycle handling needed`
);
return;
}
// If newTier is null, treat as free tier - disable all features
if (newTier === null) {
logger.info(
`Org ${orgId} is reverting to free tier, disabling all paid features`
);
// Cap retention days to free tier limits
await capRetentionDays(orgId, null);
// Disable all features in the tier matrix
for (const [featureKey] of Object.entries(tierMatrix)) {
const feature = featureKey as TierFeature;
logger.info(
`Feature ${feature} is not available in free tier for org ${orgId}. Disabling...`
);
await disableFeature(orgId, feature);
}
logger.info(
`Completed free tier feature lifecycle handling for org ${orgId}`
);
return;
}
// Get the tier (cast as Tier since we've ruled out "license" and null)
const tier = newTier as Tier;
// Cap retention days to the new tier's limits
await capRetentionDays(orgId, tier);
// Check each feature in the tier matrix
for (const [featureKey, allowedTiers] of Object.entries(tierMatrix)) {
const feature = featureKey as TierFeature;
const isFeatureAvailable = allowedTiers.includes(tier);
if (!isFeatureAvailable) {
logger.info(
`Feature ${feature} is not available in tier ${tier} for org ${orgId}. Disabling...`
);
await disableFeature(orgId, feature);
} else {
logger.debug(
`Feature ${feature} is available in tier ${tier} for org ${orgId}`
);
}
}
logger.info(
`Completed tier change feature lifecycle handling for org ${orgId}`
);
}
async function disableFeature(
orgId: string,
feature: TierFeature
): Promise<void> {
try {
switch (feature) {
case TierFeature.OrgOidc:
await disableOrgOidc(orgId);
break;
case TierFeature.LoginPageDomain:
await disableLoginPageDomain(orgId);
break;
case TierFeature.DeviceApprovals:
await disableDeviceApprovals(orgId);
break;
case TierFeature.LoginPageBranding:
await disableLoginPageBranding(orgId);
break;
case TierFeature.LogExport:
await disableLogExport(orgId);
break;
case TierFeature.AccessLogs:
await disableAccessLogs(orgId);
break;
case TierFeature.ActionLogs:
await disableActionLogs(orgId);
break;
case TierFeature.RotateCredentials:
await disableRotateCredentials(orgId);
break;
case TierFeature.MaintencePage:
await disableMaintencePage(orgId);
break;
case TierFeature.DevicePosture:
await disableDevicePosture(orgId);
break;
case TierFeature.TwoFactorEnforcement:
await disableTwoFactorEnforcement(orgId);
break;
case TierFeature.SessionDurationPolicies:
await disableSessionDurationPolicies(orgId);
break;
case TierFeature.PasswordExpirationPolicies:
await disablePasswordExpirationPolicies(orgId);
break;
case TierFeature.AutoProvisioning:
await disableAutoProvisioning(orgId);
break;
default:
logger.warn(
`Unknown feature ${feature} for org ${orgId}, skipping`
);
}
logger.info(
`Successfully disabled feature ${feature} for org ${orgId}`
);
} catch (error) {
logger.error(
`Error disabling feature ${feature} for org ${orgId}:`,
error
);
throw error;
}
}
async function disableOrgOidc(orgId: string): Promise<void> {}
async function disableDeviceApprovals(orgId: string): Promise<void> {
await db
.update(roles)
.set({ requireDeviceApproval: false })
.where(eq(roles.orgId, orgId));
logger.info(`Disabled device approvals on all roles for org ${orgId}`);
}
async function disableLoginPageBranding(orgId: string): Promise<void> {
const [existingBranding] = await db
.select()
.from(loginPageBrandingOrg)
.where(eq(loginPageBrandingOrg.orgId, orgId));
if (existingBranding) {
await db
.delete(loginPageBranding)
.where(
eq(
loginPageBranding.loginPageBrandingId,
existingBranding.loginPageBrandingId
)
);
logger.info(`Disabled login page branding for org ${orgId}`);
}
}
async function disableLoginPageDomain(orgId: string): Promise<void> {
const [existingLoginPage] = await db
.select()
.from(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId))
.innerJoin(
loginPage,
eq(loginPage.loginPageId, loginPageOrg.loginPageId)
);
if (existingLoginPage) {
await db
.delete(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId));
await db
.delete(loginPage)
.where(
eq(
loginPage.loginPageId,
existingLoginPage.loginPageOrg.loginPageId
)
);
logger.info(`Disabled login page domain for org ${orgId}`);
}
}
async function disableLogExport(orgId: string): Promise<void> {}
async function disableAccessLogs(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ settingsLogRetentionDaysAccess: 0 })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled access logs for org ${orgId}`);
}
async function disableActionLogs(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ settingsLogRetentionDaysAction: 0 })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled action logs for org ${orgId}`);
}
async function disableRotateCredentials(orgId: string): Promise<void> {}
async function disableMaintencePage(orgId: string): Promise<void> {
await db
.update(resources)
.set({
maintenanceModeEnabled: false
})
.where(eq(resources.orgId, orgId));
logger.info(`Disabled maintenance page on all resources for org ${orgId}`);
}
async function disableDevicePosture(orgId: string): Promise<void> {}
async function disableTwoFactorEnforcement(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ requireTwoFactor: false })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled two-factor enforcement for org ${orgId}`);
}
async function disableSessionDurationPolicies(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ maxSessionLengthHours: null })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled session duration policies for org ${orgId}`);
}
async function disablePasswordExpirationPolicies(orgId: string): Promise<void> {
await db
.update(orgs)
.set({ passwordExpiryDays: null })
.where(eq(orgs.orgId, orgId));
logger.info(`Disabled password expiration policies for org ${orgId}`);
}
async function disableAutoProvisioning(orgId: string): Promise<void> {
// Get all IDP IDs for this org through the idpOrg join table
const orgIdps = await db
.select({ idpId: idpOrg.idpId })
.from(idpOrg)
.where(eq(idpOrg.orgId, orgId));
// Update autoProvision to false for all IDPs in this org
for (const { idpId } of orgIdps) {
await db
.update(idp)
.set({ autoProvision: false })
.where(eq(idp.idpId, idpId));
}
}

View File

@@ -23,6 +23,8 @@ import logger from "@server/logger";
import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { GetOrgSubscriptionResponse } from "@server/routers/billing/types";
import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
// Import tables for billing
import {
@@ -70,9 +72,19 @@ export async function getOrgSubscriptions(
throw err;
}
let limitsExceeded = false;
if (build === "saas") {
try {
limitsExceeded = await usageService.checkLimitSet(orgId);
} catch (err) {
logger.error("Error checking limits for org %s: %s", orgId, err);
}
}
return response<GetOrgSubscriptionResponse>(res, {
data: {
subscriptions
subscriptions,
...(build === "saas" ? { limitsExceeded } : {})
},
success: true,
error: false,

View File

@@ -78,16 +78,10 @@ export async function getOrgUsage(
// Get usage for org
const usageData = [];
const siteUptime = await usageService.getUsage(
orgId,
FeatureId.SITE_UPTIME
);
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
const domains = await usageService.getUsageDaily(
orgId,
FeatureId.DOMAINS
);
const remoteExitNodes = await usageService.getUsageDaily(
const sites = await usageService.getUsage(orgId, FeatureId.SITES);
const users = await usageService.getUsage(orgId, FeatureId.USERS);
const domains = await usageService.getUsage(orgId, FeatureId.DOMAINS);
const remoteExitNodes = await usageService.getUsage(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
@@ -96,8 +90,8 @@ export async function getOrgUsage(
FeatureId.EGRESS_DATA_MB
);
if (siteUptime) {
usageData.push(siteUptime);
if (sites) {
usageData.push(sites);
}
if (users) {
usageData.push(users);

View File

@@ -1,35 +1,62 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/
import {
getLicensePriceSet,
} from "@server/lib/billing/licenses";
import {
getTierPriceSet,
} from "@server/lib/billing/tiers";
getTier1FeaturePriceSet,
getTier2FeaturePriceSet,
getTier3FeaturePriceSet,
} from "@server/lib/billing/features";
import Stripe from "stripe";
import { Tier } from "@server/types/Tiers";
export function getSubType(fullSubscription: Stripe.Response<Stripe.Subscription>): "saas" | "license" {
export type SubscriptionType = Tier | "license";
export function getSubType(fullSubscription: Stripe.Response<Stripe.Subscription>): SubscriptionType | null {
// Determine subscription type by checking subscription items
let type: "saas" | "license" = "saas";
if (Array.isArray(fullSubscription.items?.data)) {
for (const item of fullSubscription.items.data) {
const priceId = item.price.id;
if (!Array.isArray(fullSubscription.items?.data) || fullSubscription.items.data.length === 0) {
return null;
}
// Check if price ID matches any license price
const licensePrices = Object.values(getLicensePriceSet());
for (const item of fullSubscription.items.data) {
const priceId = item.price.id;
if (licensePrices.includes(priceId)) {
type = "license";
break;
}
// Check if price ID matches any license price
const licensePrices = Object.values(getLicensePriceSet());
if (licensePrices.includes(priceId)) {
return "license";
}
// Check if price ID matches any tier price (saas)
const tierPrices = Object.values(getTierPriceSet());
// Check if price ID matches home lab tier
const homeLabPrices = Object.values(getTier1FeaturePriceSet());
if (homeLabPrices.includes(priceId)) {
return "tier1";
}
if (tierPrices.includes(priceId)) {
type = "saas";
break;
}
// Check if price ID matches tier2 tier
const tier2Prices = Object.values(getTier2FeaturePriceSet());
if (tier2Prices.includes(priceId)) {
return "tier2";
}
// Check if price ID matches tier3 tier
const tier3Prices = Object.values(getTier3FeaturePriceSet());
if (tier3Prices.includes(priceId)) {
return "tier3";
}
}
return type;
return null;
}

View File

@@ -31,6 +31,8 @@ import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
import { sendEmail } from "@server/emails";
import EnterpriseEditionKeyGenerated from "@server/emails/templates/EnterpriseEditionKeyGenerated";
import config from "@server/lib/config";
import { getFeatureIdByPriceId } from "@server/lib/billing/features";
import { handleTierChange } from "../featureLifecycle";
export async function handleSubscriptionCreated(
subscription: Stripe.Subscription
@@ -59,6 +61,8 @@ export async function handleSubscriptionCreated(
return;
}
const type = getSubType(fullSubscription);
const newSubscription = {
subscriptionId: subscription.id,
customerId: subscription.customer as string,
@@ -66,7 +70,9 @@ export async function handleSubscriptionCreated(
canceledAt: subscription.canceled_at
? subscription.canceled_at
: null,
createdAt: subscription.created
createdAt: subscription.created,
type: type,
version: 1 // we are hardcoding the initial version when the subscription is created, and then we will increment it on every update
};
await db.insert(subscriptions).values(newSubscription);
@@ -87,10 +93,15 @@ export async function handleSubscriptionCreated(
name = product.name || null;
}
// Get the feature ID from the price ID
const featureId = getFeatureIdByPriceId(item.price.id);
return {
stripeSubscriptionItemId: item.id,
subscriptionId: subscription.id,
planId: item.plan.id,
priceId: item.price.id,
featureId: featureId || null,
meterId: item.plan.meter,
unitAmount: item.price.unit_amount || 0,
currentPeriodStart: item.current_period_start,
@@ -129,17 +140,23 @@ export async function handleSubscriptionCreated(
return;
}
const type = getSubType(fullSubscription);
if (type === "saas") {
if (type === "tier1" || type === "tier2" || type === "tier3") {
logger.debug(
`Handling SAAS subscription lifecycle for org ${customer.orgId}`
`Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}`
);
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
subscription.status,
type
);
// Handle initial tier setup - disable features not available in this tier
logger.info(
`Setting up initial tier features for org ${customer.orgId} with type ${type}`
);
await handleTierChange(customer.orgId, type);
const [orgUserRes] = await db
.select()
.from(userOrgs)

View File

@@ -27,6 +27,7 @@ import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import stripe from "#private/lib/stripe";
import privateConfig from "#private/lib/config";
import { handleTierChange } from "../featureLifecycle";
export async function handleSubscriptionDeleted(
subscription: Stripe.Subscription
@@ -76,16 +77,23 @@ export async function handleSubscriptionDeleted(
}
const type = getSubType(fullSubscription);
if (type === "saas") {
if (type == "tier1" || type == "tier2" || type == "tier3") {
logger.debug(
`Handling SaaS subscription deletion for orgId ${customer.orgId} and subscription ID ${subscription.id}`
);
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
subscription.status,
type
);
// Handle feature lifecycle for cancellation - disable all tier-specific features
logger.info(
`Disabling tier-specific features for org ${customer.orgId} due to subscription deletion`
);
await handleTierChange(customer.orgId, null, type);
const [orgUserRes] = await db
.select()
.from(userOrgs)

View File

@@ -23,11 +23,12 @@ import {
} from "@server/db";
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { getFeatureIdByMetricId } from "@server/lib/billing/features";
import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features";
import stripe from "#private/lib/stripe";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { getSubType } from "./getSubType";
import { getSubType, SubscriptionType } from "./getSubType";
import privateConfig from "#private/lib/config";
import { handleTierChange } from "../featureLifecycle";
export async function handleSubscriptionUpdated(
subscription: Stripe.Subscription,
@@ -64,6 +65,9 @@ export async function handleSubscriptionUpdated(
.where(eq(customers.customerId, subscription.customer as string))
.limit(1);
const type = getSubType(fullSubscription);
const previousType = existingSubscription.type as SubscriptionType | null;
await db
.update(subscriptions)
.set({
@@ -72,25 +76,55 @@ export async function handleSubscriptionUpdated(
? subscription.canceled_at
: null,
updatedAt: Math.floor(Date.now() / 1000),
billingCycleAnchor: subscription.billing_cycle_anchor
billingCycleAnchor: subscription.billing_cycle_anchor,
type: type
})
.where(eq(subscriptions.subscriptionId, subscription.id));
// Handle tier change if the subscription type changed
if (type && type !== previousType) {
logger.info(
`Tier change detected for org ${customer.orgId}: ${previousType} -> ${type}`
);
await handleTierChange(customer.orgId, type, previousType ?? undefined);
}
// Upsert subscription items
if (Array.isArray(fullSubscription.items?.data)) {
const itemsToUpsert = fullSubscription.items.data.map((item) => ({
subscriptionId: subscription.id,
planId: item.plan.id,
priceId: item.price.id,
meterId: item.plan.meter,
unitAmount: item.price.unit_amount || 0,
currentPeriodStart: item.current_period_start,
currentPeriodEnd: item.current_period_end,
tiers: item.price.tiers
? JSON.stringify(item.price.tiers)
: null,
interval: item.plan.interval
}));
// First, get existing items to preserve featureId when there's no match
const existingItems = await db
.select()
.from(subscriptionItems)
.where(eq(subscriptionItems.subscriptionId, subscription.id));
const itemsToUpsert = fullSubscription.items.data.map((item) => {
// Try to get featureId from price
let featureId: string | null = getFeatureIdByPriceId(item.price.id) || null;
// If no match, try to preserve existing featureId
if (!featureId) {
const existingItem = existingItems.find(
(ei) => ei.stripeSubscriptionItemId === item.id
);
featureId = existingItem?.featureId || null;
}
return {
stripeSubscriptionItemId: item.id,
subscriptionId: subscription.id,
planId: item.plan.id,
priceId: item.price.id,
featureId: featureId,
meterId: item.plan.meter,
unitAmount: item.price.unit_amount || 0,
currentPeriodStart: item.current_period_start,
currentPeriodEnd: item.current_period_end,
tiers: item.price.tiers
? JSON.stringify(item.price.tiers)
: null,
interval: item.plan.interval
};
});
if (itemsToUpsert.length > 0) {
await db.transaction(async (trx) => {
await trx
@@ -154,7 +188,7 @@ export async function handleSubscriptionUpdated(
const orgId = customer.orgId;
if (!orgId) {
logger.warn(
logger.debug(
`No orgId found in subscription metadata for subscription ${subscription.id}. Skipping usage reset.`
);
continue;
@@ -234,17 +268,29 @@ export async function handleSubscriptionUpdated(
}
// --- end usage update ---
const type = getSubType(fullSubscription);
if (type === "saas") {
if (type === "tier1" || type === "tier2" || type === "tier3") {
logger.debug(
`Handling SAAS subscription lifecycle for org ${customer.orgId}`
`Handling SAAS subscription lifecycle for org ${customer.orgId} with type ${type}`
);
// we only need to handle the limit lifecycle for saas subscriptions not for the licenses
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
subscription.status,
type
);
} else {
// Handle feature lifecycle when subscription is canceled or becomes unpaid
if (
subscription.status === "canceled" ||
subscription.status === "unpaid" ||
subscription.status === "incomplete_expired"
) {
logger.info(
`Subscription ${subscription.id} for org ${customer.orgId} is ${subscription.status}, disabling paid features`
);
await handleTierChange(customer.orgId, null, previousType ?? undefined);
}
} else if (type === "license") {
if (subscription.status === "canceled" || subscription.status == "unpaid" || subscription.status == "incomplete_expired") {
try {
// WARNING:

View File

@@ -11,8 +11,9 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./createCheckoutSessionSAAS";
export * from "./createCheckoutSession";
export * from "./createPortalSession";
export * from "./getOrgSubscriptions";
export * from "./getOrgUsage";
export * from "./internalGetOrgTier";
export * from "./changeTier";

View File

@@ -13,38 +13,66 @@
import {
freeLimitSet,
tier1LimitSet,
tier2LimitSet,
tier3LimitSet,
limitsService,
subscribedLimitSet
LimitSet
} from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import logger from "@server/logger";
import { SubscriptionType } from "./hooks/getSubType";
function getLimitSetForSubscriptionType(
subType: SubscriptionType | null
): LimitSet {
switch (subType) {
case "tier1":
return tier1LimitSet;
case "tier2":
return tier2LimitSet;
case "tier3":
return tier3LimitSet;
case "license":
// License subscriptions use tier2 limits by default
// This can be adjusted based on your business logic
return tier2LimitSet;
default:
return freeLimitSet;
}
}
export async function handleSubscriptionLifesycle(
orgId: string,
status: string
status: string,
subType: SubscriptionType | null
) {
switch (status) {
case "active":
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
await usageService.checkLimitSet(orgId, true);
const activeLimitSet = getLimitSetForSubscriptionType(subType);
await limitsService.applyLimitSetToOrg(orgId, activeLimitSet);
await usageService.checkLimitSet(orgId);
break;
case "canceled":
// Subscription canceled - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId, true);
await usageService.checkLimitSet(orgId);
break;
case "past_due":
// Optionally handle past due status, e.g., notify customer
// Payment past due - keep current limits but notify customer
// Limits will revert to free tier if it becomes unpaid
break;
case "unpaid":
// Subscription unpaid - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId, true);
await usageService.checkLimitSet(orgId);
break;
case "incomplete":
// Optionally handle incomplete status, e.g., notify customer
// Payment incomplete - give them time to complete payment
break;
case "incomplete_expired":
// Payment never completed - revert to free tier
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
await usageService.checkLimitSet(orgId, true);
await usageService.checkLimitSet(orgId);
break;
default:
break;

View File

@@ -31,7 +31,8 @@ import {
verifyUserHasAction,
verifyUserIsServerAdmin,
verifySiteAccess,
verifyClientAccess
verifyClientAccess,
verifyLimits
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import {
@@ -52,6 +53,7 @@ import {
authenticated as a,
authRouter as aa
} from "@server/routers/external";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export const authenticated = a;
export const unauthenticated = ua;
@@ -76,7 +78,9 @@ unauthenticated.post(
authenticated.put(
"/org/:orgId/idp/oidc",
verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp),
orgIdp.createOrgOidcIdp
@@ -85,8 +89,10 @@ authenticated.put(
authenticated.post(
"/org/:orgId/idp/:idpId/oidc",
verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyOrgAccess,
verifyIdpAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateIdp),
logActionAudit(ActionsEnum.updateIdp),
orgIdp.updateOrgOidcIdp
@@ -135,35 +141,27 @@ authenticated.post(
verifyValidLicense,
verifyOrgAccess,
verifyCertificateAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.restartCertificate),
logActionAudit(ActionsEnum.restartCertificate),
certificates.restartCertificate
);
if (build === "saas") {
unauthenticated.post(
"/quick-start",
rateLimit({
windowMs: 15 * 60 * 1000,
max: 100,
keyGenerator: (req) => req.path,
handler: (req, res, next) => {
const message = `We're too busy right now. Please try again later.`;
return next(
createHttpError(HttpCode.TOO_MANY_REQUESTS, message)
);
},
store: createStore()
}),
auth.quickStart
);
authenticated.post(
"/org/:orgId/billing/create-checkout-session-saas",
"/org/:orgId/billing/create-checkout-session",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.billing),
logActionAudit(ActionsEnum.billing),
billing.createCheckoutSessionSAAS
billing.createCheckoutSession
);
authenticated.post(
"/org/:orgId/billing/change-tier",
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.billing),
logActionAudit(ActionsEnum.billing),
billing.changeTier
);
authenticated.post(
@@ -243,6 +241,7 @@ authenticated.put(
"/org/:orgId/remote-exit-node",
verifyValidLicense,
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createRemoteExitNode),
logActionAudit(ActionsEnum.createRemoteExitNode),
remoteExitNode.createRemoteExitNode
@@ -286,7 +285,9 @@ authenticated.delete(
authenticated.put(
"/org/:orgId/login-page",
verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageDomain),
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.createLoginPage),
logActionAudit(ActionsEnum.createLoginPage),
loginPage.createLoginPage
@@ -295,8 +296,10 @@ authenticated.put(
authenticated.post(
"/org/:orgId/login-page/:loginPageId",
verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageDomain),
verifyOrgAccess,
verifyLoginPageAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateLoginPage),
logActionAudit(ActionsEnum.updateLoginPage),
loginPage.updateLoginPage
@@ -323,6 +326,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/approvals",
verifyValidLicense,
verifyValidSubscription(tierMatrix.deviceApprovals),
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.listApprovals),
logActionAudit(ActionsEnum.listApprovals),
@@ -339,7 +343,9 @@ authenticated.get(
authenticated.put(
"/org/:orgId/approvals/:approvalId",
verifyValidLicense,
verifyValidSubscription(tierMatrix.deviceApprovals),
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateApprovals),
logActionAudit(ActionsEnum.updateApprovals),
approval.processPendingApproval
@@ -348,6 +354,7 @@ authenticated.put(
authenticated.get(
"/org/:orgId/login-page-branding",
verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageBranding),
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.getLoginPage),
logActionAudit(ActionsEnum.getLoginPage),
@@ -357,7 +364,9 @@ authenticated.get(
authenticated.put(
"/org/:orgId/login-page-branding",
verifyValidLicense,
verifyValidSubscription(tierMatrix.loginPageBranding),
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.updateLoginPage),
logActionAudit(ActionsEnum.updateLoginPage),
loginPage.upsertLoginPageBranding
@@ -433,7 +442,7 @@ authenticated.post(
authenticated.get(
"/org/:orgId/logs/action",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.actionLogs),
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs),
logs.queryActionAuditLogs
@@ -442,7 +451,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/logs/action/export",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.logExport),
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
@@ -452,7 +461,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/logs/access",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.accessLogs),
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs),
logs.queryAccessAuditLogs
@@ -461,7 +470,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/logs/access/export",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.logExport),
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
@@ -470,18 +479,20 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyClientAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifyClientAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret
);
authenticated.post(
"/re-key/:siteId/regenerate-site-secret",
verifySiteAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifySiteAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret
);
@@ -489,8 +500,9 @@ authenticated.post(
authenticated.put(
"/re-key/:orgId/regenerate-remote-exit-node-secret",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateExitNodeSecret
);

View File

@@ -134,6 +134,7 @@ export async function generateNewEnterpriseLicense(
], // Start with the standard feature set that matches the free limits
customer: customer.customerId,
mode: "subscription",
allow_promotion_codes: true,
success_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?success=true&session_id={CHECKOUT_SESSION_ID}`,
cancel_url: `${config.getRawConfig().app.dashboard_url}/${orgId}/settings/license?canceled=true`
});

View File

@@ -19,21 +19,20 @@ import {
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess
verifyApiKeyIdpAccess,
verifyLimits
} from "@server/middlewares";
import {
verifyValidSubscription,
verifyValidLicense
} from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import {
unauthenticated as ua,
authenticated as a
} from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares";
import config from "#private/lib/config";
import { build } from "@server/build";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
export const unauthenticated = ua;
export const authenticated = a;
@@ -57,7 +56,7 @@ authenticated.delete(
authenticated.get(
"/org/:orgId/logs/action",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.actionLogs),
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logs.queryActionAuditLogs
@@ -66,7 +65,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/logs/action/export",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.logExport),
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
@@ -76,7 +75,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/logs/access",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.accessLogs),
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logs.queryAccessAuditLogs
@@ -85,7 +84,7 @@ authenticated.get(
authenticated.get(
"/org/:orgId/logs/access/export",
verifyValidLicense,
verifyValidSubscription,
verifyValidSubscription(tierMatrix.logExport),
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.exportLogs),
logActionAudit(ActionsEnum.exportLogs),
@@ -95,7 +94,9 @@ authenticated.get(
authenticated.put(
"/org/:orgId/idp/oidc",
verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyApiKeyOrgAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.createIdp),
logActionAudit(ActionsEnum.createIdp),
orgIdp.createOrgOidcIdp
@@ -104,8 +105,10 @@ authenticated.put(
authenticated.post(
"/org/:orgId/idp/:idpId/oidc",
verifyValidLicense,
verifyValidSubscription(tierMatrix.orgOidc),
verifyApiKeyOrgAccess,
verifyApiKeyIdpAccess,
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.updateIdp),
logActionAudit(ActionsEnum.updateIdp),
orgIdp.updateOrgOidcIdp

View File

@@ -30,9 +30,7 @@ import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { createCertificate } from "#private/routers/certificates/createCertificate";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
@@ -76,19 +74,6 @@ export async function createLoginPage(
const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existing] = await db
.select()
.from(loginPageOrg)

View File

@@ -25,9 +25,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z
.object({
@@ -53,18 +51,6 @@ export async function deleteLoginPageBranding(
const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPageBranding] = await db
.select()

View File

@@ -25,9 +25,7 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
const paramsSchema = z.strictObject({
orgId: z.string()
@@ -51,19 +49,6 @@ export async function getLoginPageBranding(
const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPageBranding] = await db
.select()
.from(loginPageBranding)

View File

@@ -23,9 +23,7 @@ import { eq, and } from "drizzle-orm";
import { validateAndConstructDomain } from "@server/lib/domainUtils";
import { subdomainSchema } from "@server/lib/schemas";
import { createCertificate } from "#private/routers/certificates/createCertificate";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import { UpdateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z
@@ -87,18 +85,6 @@ export async function updateLoginPage(
const { loginPageId, orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
const [existingLoginPage] = await db
.select()

View File

@@ -25,10 +25,8 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, InferInsertModel } from "drizzle-orm";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { build } from "@server/build";
import config from "@server/private/lib/config";
import config from "#private/lib/config";
const paramsSchema = z.strictObject({
orgId: z.string()
@@ -128,19 +126,6 @@ export async function upsertLoginPageBranding(
const { orgId } = parsedParams.data;
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
let updateData = parsedBody.data satisfies InferInsertModel<
typeof loginPageBranding
>;

View File

@@ -24,10 +24,10 @@ import { idp, idpOidcConfig, idpOrg, orgs } from "@server/db";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
import { isSubscribed } from "#private/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import privateConfig from "#private/lib/config";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
@@ -93,6 +93,18 @@ export async function createOrgOidcIdp(
);
}
if (
privateConfig.getRawPrivateConfig().app.identity_provider_mode !==
"org"
) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization-specific IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'org' in the private configuration to enable this feature."
)
);
}
const {
clientId,
clientSecret,
@@ -103,23 +115,19 @@ export async function createOrgOidcIdp(
emailPath,
namePath,
name,
autoProvision,
variant,
roleMapping,
tags
} = parsedBody.data;
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
let { autoProvision } = parsedBody.data;
const subscribed = await isSubscribed(
orgId,
tierMatrix.deviceApprovals
);
if (!subscribed) {
autoProvision = false;
}
const key = config.getRawConfig().server.secret!;

View File

@@ -22,6 +22,7 @@ import { fromError } from "zod-validation-error";
import { idp, idpOidcConfig, idpOrg } from "@server/db";
import { eq } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import privateConfig from "#private/lib/config";
const paramsSchema = z
.object({
@@ -59,6 +60,18 @@ export async function deleteOrgIdp(
const { idpId } = parsedParams.data;
if (
privateConfig.getRawPrivateConfig().app.identity_provider_mode !==
"org"
) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization-specific IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'org' in the private configuration to enable this feature."
)
);
}
// Check if IDP exists
const [existingIdp] = await db
.select()

View File

@@ -24,9 +24,9 @@ import { idp, idpOidcConfig } from "@server/db";
import { eq, and } from "drizzle-orm";
import { encrypt } from "@server/lib/crypto";
import config from "@server/lib/config";
import { build } from "@server/build";
import { getOrgTierData } from "#private/lib/billing";
import { TierId } from "@server/lib/billing/tiers";
import { isSubscribed } from "#private/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import privateConfig from "#private/lib/config";
const paramsSchema = z
.object({
@@ -98,6 +98,18 @@ export async function updateOrgOidcIdp(
);
}
if (
privateConfig.getRawPrivateConfig().app.identity_provider_mode !==
"org"
) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization-specific IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'org' in the private configuration to enable this feature."
)
);
}
const { idpId, orgId } = parsedParams.data;
const {
clientId,
@@ -109,22 +121,18 @@ export async function updateOrgOidcIdp(
emailPath,
namePath,
name,
autoProvision,
roleMapping,
tags
} = parsedBody.data;
if (build === "saas") {
const { tier, active } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
let { autoProvision } = parsedBody.data;
const subscribed = await isSubscribed(
orgId,
tierMatrix.deviceApprovals
);
if (!subscribed) {
autoProvision = false;
}
// Check if IDP exists and is of type OIDC

View File

@@ -85,7 +85,7 @@ export async function createRemoteExitNode(
if (usage) {
const rejectRemoteExitNodes = await usageService.checkLimitSet(
orgId,
false,
FeatureId.REMOTE_EXIT_NODES,
{
...usage,
@@ -97,7 +97,7 @@ export async function createRemoteExitNode(
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Remote exit node limit exceeded. Please upgrade your plan or contact us at support@pangolin.net"
"Remote node limit exceeded. Please upgrade your plan."
)
);
}
@@ -224,7 +224,7 @@ export async function createRemoteExitNode(
});
if (numExitNodeOrgs) {
await usageService.updateDaily(
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length

View File

@@ -106,7 +106,7 @@ export async function deleteRemoteExitNode(
});
if (numExitNodeOrgs) {
await usageService.updateDaily(
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length