From 4d6240c987332b261b4fe53b5839d68ad0205a1b Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 17 Feb 2026 17:09:48 -0800 Subject: [PATCH] Handle new usage tracking with multi org --- server/lib/billing/features.ts | 3 + server/lib/billing/limitSet.ts | 17 +- server/lib/billing/usageService.ts | 183 +--------------- server/lib/createUserAccountOrg.ts | 206 ------------------ server/lib/deleteOrg.ts | 27 +-- server/private/routers/billing/getOrgUsage.ts | 17 +- .../remoteExitNode/createRemoteExitNode.ts | 71 ++++-- .../remoteExitNode/deleteRemoteExitNode.ts | 70 +++--- server/routers/domain/createOrgDomain.ts | 14 +- server/routers/domain/deleteOrgDomain.ts | 15 +- server/routers/org/createOrg.ts | 80 +++++-- server/routers/site/createSite.ts | 16 +- server/routers/site/deleteSite.ts | 15 +- server/routers/user/acceptInvite.ts | 68 ++++-- server/routers/user/createOrgUser.ts | 64 ++++-- server/routers/user/removeUserOrg.ts | 87 ++++++-- .../settings/(private)/billing/page.tsx | 63 +++++- 17 files changed, 432 insertions(+), 584 deletions(-) delete mode 100644 server/lib/createUserAccountOrg.ts diff --git a/server/lib/billing/features.ts b/server/lib/billing/features.ts index 3fec53b4..6063b470 100644 --- a/server/lib/billing/features.ts +++ b/server/lib/billing/features.ts @@ -4,6 +4,7 @@ export enum FeatureId { EGRESS_DATA_MB = "egressDataMb", DOMAINS = "domains", REMOTE_EXIT_NODES = "remoteExitNodes", + ORGINIZATIONS = "organizations", TIER1 = "tier1" } @@ -19,6 +20,8 @@ export async function getFeatureDisplayName(featureId: FeatureId): Promise; -export const sandboxLimitSet: LimitSet = { - [FeatureId.USERS]: { value: 1, description: "Sandbox limit" }, - [FeatureId.SITES]: { value: 1, description: "Sandbox limit" }, - [FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" }, - [FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }, -}; - export const freeLimitSet: LimitSet = { [FeatureId.SITES]: { value: 5, description: "Basic limit" }, [FeatureId.USERS]: { value: 5, description: "Basic limit" }, [FeatureId.DOMAINS]: { value: 5, description: "Basic limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" }, + [FeatureId.ORGINIZATIONS]: { value: 1, description: "Basic limit" }, }; export const tier1LimitSet: LimitSet = { @@ -26,6 +20,7 @@ export const tier1LimitSet: LimitSet = { [FeatureId.SITES]: { value: 10, description: "Home limit" }, [FeatureId.DOMAINS]: { value: 10, description: "Home limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" }, + [FeatureId.ORGINIZATIONS]: { value: 1, description: "Home limit" }, }; export const tier2LimitSet: LimitSet = { @@ -45,6 +40,10 @@ export const tier2LimitSet: LimitSet = { value: 3, description: "Team limit" }, + [FeatureId.ORGINIZATIONS]: { + value: 1, + description: "Team limit" + } }; export const tier3LimitSet: LimitSet = { @@ -64,4 +63,8 @@ export const tier3LimitSet: LimitSet = { value: 20, description: "Business limit" }, + [FeatureId.ORGINIZATIONS]: { + value: 20, + description: "Business limit" + }, }; diff --git a/server/lib/billing/usageService.ts b/server/lib/billing/usageService.ts index 03fa42a8..cde6cd2a 100644 --- a/server/lib/billing/usageService.ts +++ b/server/lib/billing/usageService.ts @@ -1,12 +1,8 @@ import { eq, sql, and } from "drizzle-orm"; -import { v4 as uuidv4 } from "uuid"; -import { PutObjectCommand } from "@aws-sdk/client-s3"; import { db, usage, customers, - sites, - newts, limits, Usage, Limit, @@ -15,21 +11,9 @@ import { } from "@server/db"; import { FeatureId, getFeatureMeterId } from "./features"; import logger from "@server/logger"; -import { sendToClient } from "#dynamic/routers/ws"; import { build } from "@server/build"; -import { s3Client } from "@server/lib/s3"; import cache from "@server/lib/cache"; -interface StripeEvent { - identifier?: string; - timestamp: number; - event_name: string; - payload: { - value: number; - stripe_customer_id: string; - }; -} - export function noop() { if (build !== "saas") { return true; @@ -38,41 +22,11 @@ export function noop() { } export class UsageService { - // private bucketName: string | undefined; - // private events: StripeEvent[] = []; - // private lastUploadTime: number = Date.now(); - // private isUploading: boolean = false; constructor() { if (noop()) { return; } - - // this.bucketName = process.env.S3_BUCKET || undefined; - - // // Periodically check and upload events - // setInterval(() => { - // this.checkAndUploadEvents().catch((err) => { - // logger.error("Error in periodic event upload:", err); - // }); - // }, 30000); // every 30 seconds - - // // Handle graceful shutdown on SIGTERM - // process.on("SIGTERM", async () => { - // logger.info( - // "SIGTERM received, uploading events before shutdown..." - // ); - // await this.forceUpload(); - // logger.info("Events uploaded, proceeding with shutdown"); - // }); - - // // Handle SIGINT as well (Ctrl+C) - // process.on("SIGINT", async () => { - // logger.info("SIGINT received, uploading events before shutdown..."); - // await this.forceUpload(); - // logger.info("Events uploaded, proceeding with shutdown"); - // process.exit(0); - // }); } /** @@ -103,16 +57,6 @@ export class UsageService { while (attempt <= maxRetries) { try { - // // Get subscription data for this org (with caching) - // const customerId = await this.getCustomerId(orgIdToUse, featureId); - - // if (!customerId) { - // logger.warn( - // `No subscription data found for org ${orgIdToUse} and feature ${featureId}` - // ); - // return null; - // } - let usage; if (transaction) { usage = await this.internalAddUsage( @@ -132,11 +76,6 @@ export class UsageService { }); } - // Log event for Stripe - // if (privateConfig.getRawPrivateConfig().flags.usage_reporting) { - // await this.logStripeEvent(featureId, value, customerId); - // } - return usage || null; } catch (error: any) { // Check if this is a deadlock error @@ -191,13 +130,14 @@ export class UsageService { featureId, orgId, meterId, + instantaneousValue: value, latestValue: value, updatedAt: Math.floor(Date.now() / 1000) }) .onConflictDoUpdate({ target: usage.usageId, set: { - latestValue: sql`${usage.latestValue} + ${value}` + instantaneousValue: sql`${usage.instantaneousValue} + ${value}` } }) .returning(); @@ -228,17 +168,6 @@ export class UsageService { let orgIdToUse = await this.getBillingOrg(orgId); try { - // if (!customerId) { - // customerId = - // (await this.getCustomerId(orgIdToUse, featureId)) || undefined; - // if (!customerId) { - // logger.warn( - // `No subscription data found for org ${orgIdToUse} and feature ${featureId}` - // ); - // return; - // } - // } - // Truncate value to 11 decimal places if provided if (value !== undefined && value !== null) { value = this.truncateValue(value); @@ -523,114 +452,6 @@ export class UsageService { return hasExceededLimits; } - - // private async logStripeEvent( - // featureId: FeatureId, - // value: number, - // customerId: string - // ): Promise { - // // Truncate value to 11 decimal places before sending to Stripe - // const truncatedValue = this.truncateValue(value); - - // const event: StripeEvent = { - // identifier: uuidv4(), - // timestamp: Math.floor(new Date().getTime() / 1000), - // event_name: featureId, - // payload: { - // value: truncatedValue, - // stripe_customer_id: customerId - // } - // }; - - // this.addEventToMemory(event); - // await this.checkAndUploadEvents(); - // } - - // private addEventToMemory(event: StripeEvent): void { - // if (!this.bucketName) { - // logger.warn( - // "S3 bucket name is not configured, skipping event storage." - // ); - // return; - // } - // this.events.push(event); - // } - - // private async checkAndUploadEvents(): Promise { - // const now = Date.now(); - // const timeSinceLastUpload = now - this.lastUploadTime; - - // // Check if at least 1 minute has passed since last upload - // if (timeSinceLastUpload >= 60000 && this.events.length > 0) { - // await this.uploadEventsToS3(); - // } - // } - - // private async uploadEventsToS3(): Promise { - // if (!this.bucketName) { - // logger.warn( - // "S3 bucket name is not configured, skipping S3 upload." - // ); - // return; - // } - - // if (this.events.length === 0) { - // return; - // } - - // // Check if already uploading - // if (this.isUploading) { - // logger.debug("Already uploading events, skipping"); - // return; - // } - - // this.isUploading = true; - - // try { - // // Take a snapshot of current events and clear the array - // const eventsToUpload = [...this.events]; - // this.events = []; - // this.lastUploadTime = Date.now(); - - // const fileName = this.generateEventFileName(); - // const fileContent = JSON.stringify(eventsToUpload, null, 2); - - // // Upload to S3 - // const uploadCommand = new PutObjectCommand({ - // Bucket: this.bucketName, - // Key: fileName, - // Body: fileContent, - // ContentType: "application/json" - // }); - - // await s3Client.send(uploadCommand); - - // logger.info( - // `Uploaded ${fileName} to S3 with ${eventsToUpload.length} events` - // ); - // } catch (error) { - // logger.error("Failed to upload events to S3:", error); - // // Note: Events are lost if upload fails. In a production system, - // // you might want to add the events back to the array or implement retry logic - // } finally { - // this.isUploading = false; - // } - // } - - // private generateEventFileName(): string { - // const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); - // const uuid = uuidv4().substring(0, 8); - // return `events-${timestamp}-${uuid}.json`; - // } - - // public async forceUpload(): Promise { - // if (this.events.length > 0) { - // // Force upload regardless of time - // this.lastUploadTime = 0; // Reset to force upload - // await this.uploadEventsToS3(); - // } - // } - } export const usageService = new UsageService(); diff --git a/server/lib/createUserAccountOrg.ts b/server/lib/createUserAccountOrg.ts deleted file mode 100644 index a40407d1..00000000 --- a/server/lib/createUserAccountOrg.ts +++ /dev/null @@ -1,206 +0,0 @@ -import { isValidCIDR } from "@server/lib/validators"; -import { getNextAvailableOrgSubnet } from "@server/lib/ip"; -import { - actions, - apiKeyOrg, - apiKeys, - db, - domains, - Org, - orgDomains, - orgs, - roleActions, - roles, - userOrgs -} from "@server/db"; -import { eq } from "drizzle-orm"; -import { defaultRoleAllowedActions } from "@server/routers/role"; -import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing"; -import { createCustomer } from "#dynamic/lib/billing"; -import { usageService } from "@server/lib/billing/usageService"; -import config from "@server/lib/config"; -import { generateCA } from "@server/private/lib/sshCA"; -import { encrypt } from "@server/lib/crypto"; - -export async function createUserAccountOrg( - userId: string, - userEmail: string -): Promise<{ - success: boolean; - org?: { - orgId: string; - name: string; - subnet: string; - }; - error?: string; -}> { - // const subnet = await getNextAvailableOrgSubnet(); - const orgId = "org_" + userId; - const name = `${userEmail}'s Organization`; - - // if (!isValidCIDR(subnet)) { - // return { - // success: false, - // error: "Invalid subnet format. Please provide a valid CIDR notation." - // }; - // } - - // // make sure the subnet is unique - // const subnetExists = await db - // .select() - // .from(orgs) - // .where(eq(orgs.subnet, subnet)) - // .limit(1); - - // if (subnetExists.length > 0) { - // return { success: false, error: `Subnet ${subnet} already exists` }; - // } - - // make sure the orgId is unique - const orgExists = await db - .select() - .from(orgs) - .where(eq(orgs.orgId, orgId)) - .limit(1); - - if (orgExists.length > 0) { - return { - success: false, - error: `Organization with ID ${orgId} already exists` - }; - } - - let error = ""; - let org: Org | null = null; - - await db.transaction(async (trx) => { - const allDomains = await trx - .select() - .from(domains) - .where(eq(domains.configManaged, true)); - - const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group; - - // Generate SSH CA keys for the org - const ca = generateCA(`${orgId}-ca`); - const encryptionKey = config.getRawConfig().server.secret!; - const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey); - - const newOrg = await trx - .insert(orgs) - .values({ - orgId, - name, - // subnet - subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs? - utilitySubnet: utilitySubnet, - createdAt: new Date().toISOString(), - sshCaPrivateKey: encryptedCaPrivateKey, - sshCaPublicKey: ca.publicKeyOpenSSH - }) - .returning(); - - if (newOrg.length === 0) { - error = "Failed to create organization"; - trx.rollback(); - return; - } - - org = newOrg[0]; - - // Create admin role within the same transaction - const [insertedRole] = await trx - .insert(roles) - .values({ - orgId: newOrg[0].orgId, - isAdmin: true, - name: "Admin", - description: "Admin role with the most permissions" - }) - .returning({ roleId: roles.roleId }); - - if (!insertedRole || !insertedRole.roleId) { - error = "Failed to create Admin role"; - trx.rollback(); - return; - } - - const roleId = insertedRole.roleId; - - // Get all actions and create role actions - const actionIds = await trx.select().from(actions).execute(); - - if (actionIds.length > 0) { - await trx.insert(roleActions).values( - actionIds.map((action) => ({ - roleId, - actionId: action.actionId, - orgId: newOrg[0].orgId - })) - ); - } - - if (allDomains.length) { - await trx.insert(orgDomains).values( - allDomains.map((domain) => ({ - orgId: newOrg[0].orgId, - domainId: domain.domainId - })) - ); - } - - await trx.insert(userOrgs).values({ - userId, - orgId: newOrg[0].orgId, - roleId: roleId, - isOwner: true - }); - - const memberRole = await trx - .insert(roles) - .values({ - name: "Member", - description: "Members can only view resources", - orgId - }) - .returning(); - - await trx.insert(roleActions).values( - defaultRoleAllowedActions.map((action) => ({ - roleId: memberRole[0].roleId, - actionId: action, - orgId - })) - ); - }); - - await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet); - - if (!org) { - return { success: false, error: "Failed to create org" }; - } - - if (error) { - return { - success: false, - error: `Failed to create org: ${error}` - }; - } - - // make sure we have the stripe customer - const customerId = await createCustomer(orgId, userEmail); - - if (customerId) { - await usageService.updateCount(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org - } - - return { - org: { - orgId, - name, - // subnet - subnet: "100.90.128.0/24" - }, - success: true - }; -} diff --git a/server/lib/deleteOrg.ts b/server/lib/deleteOrg.ts index 7295555d..856759ab 100644 --- a/server/lib/deleteOrg.ts +++ b/server/lib/deleteOrg.ts @@ -19,6 +19,8 @@ import { sendToClient } from "#dynamic/routers/ws"; import { deletePeer } from "@server/routers/gerbil/peers"; import { OlmErrorCodes } from "@server/routers/olm/error"; import { sendTerminateClient } from "@server/routers/client/terminate"; +import { usageService } from "./billing/usageService"; +import { FeatureId } from "./billing"; export type DeleteOrgByIdResult = { deletedNewtIds: string[]; @@ -74,9 +76,7 @@ export async function deleteOrgById( deletedNewtIds.push(deletedNewt.newtId); await trx .delete(newtSessions) - .where( - eq(newtSessions.newtId, deletedNewt.newtId) - ); + .where(eq(newtSessions.newtId, deletedNewt.newtId)); } } } @@ -137,6 +137,9 @@ export async function deleteOrgById( .where(inArray(domains.domainId, domainIdsToDelete)); } await trx.delete(resources).where(eq(resources.orgId, orgId)); + + await usageService.add(orgId, FeatureId.ORGINIZATIONS, -1, trx); // here we are decreasing the org count BEFORE deleting the org because we need to still be able to get the org to get the billing org inside of here + await trx.delete(orgs).where(eq(orgs.orgId, orgId)); }); @@ -155,15 +158,13 @@ export function sendTerminationMessages(result: DeleteOrgByIdResult): void { ); } for (const olmId of result.olmsToTerminate) { - sendTerminateClient( - 0, - OlmErrorCodes.TERMINATED_REKEYED, - olmId - ).catch((error) => { - logger.error( - "Failed to send termination message to olm:", - error - ); - }); + sendTerminateClient(0, OlmErrorCodes.TERMINATED_REKEYED, olmId).catch( + (error) => { + logger.error( + "Failed to send termination message to olm:", + error + ); + } + ); } } diff --git a/server/private/routers/billing/getOrgUsage.ts b/server/private/routers/billing/getOrgUsage.ts index cf4e7585..4c9f22f3 100644 --- a/server/private/routers/billing/getOrgUsage.ts +++ b/server/private/routers/billing/getOrgUsage.ts @@ -85,10 +85,14 @@ export async function getOrgUsage( orgId, FeatureId.REMOTE_EXIT_NODES ); - const egressData = await usageService.getUsage( + const organizations = await usageService.getUsage( orgId, - FeatureId.EGRESS_DATA_MB + FeatureId.ORGINIZATIONS ); + // const egressData = await usageService.getUsage( + // orgId, + // FeatureId.EGRESS_DATA_MB + // ); if (sites) { usageData.push(sites); @@ -96,15 +100,18 @@ export async function getOrgUsage( if (users) { usageData.push(users); } - if (egressData) { - usageData.push(egressData); - } + // if (egressData) { + // usageData.push(egressData); + // } if (domains) { usageData.push(domains); } if (remoteExitNodes) { usageData.push(remoteExitNodes); } + if (organizations) { + usageData.push(organizations); + } const orgLimits = await db .select() diff --git a/server/private/routers/remoteExitNode/createRemoteExitNode.ts b/server/private/routers/remoteExitNode/createRemoteExitNode.ts index 14541736..6d5b5ea6 100644 --- a/server/private/routers/remoteExitNode/createRemoteExitNode.ts +++ b/server/private/routers/remoteExitNode/createRemoteExitNode.ts @@ -12,7 +12,14 @@ */ import { NextFunction, Request, Response } from "express"; -import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db"; +import { + db, + exitNodes, + exitNodeOrgs, + ExitNode, + ExitNodeOrg, + orgs +} from "@server/db"; import HttpCode from "@server/types/HttpCode"; import { z } from "zod"; import { remoteExitNodes } from "@server/db"; @@ -25,7 +32,7 @@ import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNo import { fromError } from "zod-validation-error"; import { hashPassword, verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray, ne } from "drizzle-orm"; import { getNextAvailableSubnet } from "@server/lib/exitNodes"; import { usageService } from "@server/lib/billing/usageService"; import { FeatureId } from "@server/lib/billing"; @@ -169,7 +176,17 @@ export async function createRemoteExitNode( ); } - let numExitNodeOrgs: ExitNodeOrg[] | undefined; + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Organization not found") + ); + } await db.transaction(async (trx) => { if (!existingExitNode) { @@ -217,19 +234,43 @@ export async function createRemoteExitNode( }); } - numExitNodeOrgs = await trx - .select() - .from(exitNodeOrgs) - .where(eq(exitNodeOrgs.orgId, orgId)); - }); + // calculate if the node is in any other of the orgs before we count it as an add to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where( + and( + eq(orgs.billingOrgId, org.billingOrgId), + ne(orgs.orgId, orgId) + ) + ); - if (numExitNodeOrgs) { - await usageService.updateCount( - orgId, - FeatureId.REMOTE_EXIT_NODES, - numExitNodeOrgs.length - ); - } + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheNodeIsStillIn = await trx + .select() + .from(exitNodeOrgs) + .where( + and( + eq( + exitNodeOrgs.exitNodeId, + existingExitNode.exitNodeId + ), + inArray(exitNodeOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) { + await usageService.add( + orgId, + FeatureId.REMOTE_EXIT_NODES, + 1, + trx + ); + } + } + }); const token = generateSessionToken(); await createRemoteExitNodeSession(token, remoteExitNodeId); diff --git a/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts b/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts index 8337f05d..6ff6841c 100644 --- a/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts +++ b/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts @@ -13,9 +13,9 @@ import { NextFunction, Request, Response } from "express"; import { z } from "zod"; -import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db"; +import { db, ExitNodeOrg, exitNodeOrgs, exitNodes, orgs } from "@server/db"; import { remoteExitNodes } from "@server/db"; -import { and, count, eq } from "drizzle-orm"; +import { and, count, eq, inArray } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -50,7 +50,8 @@ export async function deleteRemoteExitNode( const [remoteExitNode] = await db .select() .from(remoteExitNodes) - .where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)); + .where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)) + .limit(1); if (!remoteExitNode) { return next( @@ -70,7 +71,17 @@ export async function deleteRemoteExitNode( ); } - let numExitNodeOrgs: ExitNodeOrg[] | undefined; + const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); + + if (!org) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Org with ID ${orgId} not found` + ) + ); + } + await db.transaction(async (trx) => { await trx .delete(exitNodeOrgs) @@ -81,38 +92,39 @@ export async function deleteRemoteExitNode( ) ); - const [remainingExitNodeOrgs] = await trx - .select({ count: count() }) - .from(exitNodeOrgs) - .where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!)); + // calculate if the user is in any other of the orgs before we count it as an remove to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where(eq(orgs.billingOrgId, org.billingOrgId)); - if (remainingExitNodeOrgs.count === 0) { - await trx - .delete(remoteExitNodes) + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheNodeIsStillIn = await trx + .select() + .from(exitNodeOrgs) .where( - eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId) + and( + eq( + exitNodeOrgs.exitNodeId, + remoteExitNode.exitNodeId! + ), + inArray(exitNodeOrgs.orgId, billingOrgIds) + ) ); - await trx - .delete(exitNodes) - .where( - eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!) + + if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) { + await usageService.add( + orgId, + FeatureId.REMOTE_EXIT_NODES, + -1, + trx ); + } } - - numExitNodeOrgs = await trx - .select() - .from(exitNodeOrgs) - .where(eq(exitNodeOrgs.orgId, orgId)); }); - if (numExitNodeOrgs) { - await usageService.updateCount( - orgId, - FeatureId.REMOTE_EXIT_NODES, - numExitNodeOrgs.length - ); - } - return response(res, { data: null, success: true, diff --git a/server/routers/domain/createOrgDomain.ts b/server/routers/domain/createOrgDomain.ts index 35fb305f..ceb61b25 100644 --- a/server/routers/domain/createOrgDomain.ts +++ b/server/routers/domain/createOrgDomain.ts @@ -148,7 +148,6 @@ export async function createOrgDomain( } } - let numOrgDomains: OrgDomains[] | undefined; let aRecords: CreateDomainResponse["aRecords"]; let cnameRecords: CreateDomainResponse["cnameRecords"]; let txtRecords: CreateDomainResponse["txtRecords"]; @@ -347,20 +346,9 @@ export async function createOrgDomain( await trx.insert(dnsRecords).values(recordsToInsert); } - numOrgDomains = await trx - .select() - .from(orgDomains) - .where(eq(orgDomains.orgId, orgId)); + await usageService.add(orgId, FeatureId.DOMAINS, 1, trx); }); - if (numOrgDomains) { - await usageService.updateCount( - orgId, - FeatureId.DOMAINS, - numOrgDomains.length - ); - } - if (!returned) { return next( createHttpError( diff --git a/server/routers/domain/deleteOrgDomain.ts b/server/routers/domain/deleteOrgDomain.ts index 04829a13..4c347668 100644 --- a/server/routers/domain/deleteOrgDomain.ts +++ b/server/routers/domain/deleteOrgDomain.ts @@ -36,8 +36,6 @@ export async function deleteAccountDomain( } const { domainId, orgId } = parsed.data; - let numOrgDomains: OrgDomains[] | undefined; - await db.transaction(async (trx) => { const [existing] = await trx .select() @@ -79,20 +77,9 @@ export async function deleteAccountDomain( await trx.delete(domains).where(eq(domains.domainId, domainId)); - numOrgDomains = await trx - .select() - .from(orgDomains) - .where(eq(orgDomains.orgId, orgId)); + await usageService.add(orgId, FeatureId.DOMAINS, -1, trx); }); - if (numOrgDomains) { - await usageService.updateCount( - orgId, - FeatureId.DOMAINS, - numOrgDomains.length - ); - } - return response(res, { data: { success: true }, success: true, diff --git a/server/routers/org/createOrg.ts b/server/routers/org/createOrg.ts index 0c135e5b..45771492 100644 --- a/server/routers/org/createOrg.ts +++ b/server/routers/org/createOrg.ts @@ -1,7 +1,7 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { db } from "@server/db"; -import { and, eq } from "drizzle-orm"; +import { and, count, eq } from "drizzle-orm"; import { domains, Org, @@ -24,11 +24,7 @@ import { OpenAPITags, registry } from "@server/openApi"; import { isValidCIDR } from "@server/lib/validators"; import { createCustomer } from "#dynamic/lib/billing"; import { usageService } from "@server/lib/billing/usageService"; -import { - FeatureId, - limitsService, - sandboxLimitSet -} from "@server/lib/billing"; +import { FeatureId, limitsService, freeLimitSet } from "@server/lib/billing"; import { build } from "@server/build"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { doCidrsOverlap } from "@server/lib/ip"; @@ -114,6 +110,7 @@ export async function createOrg( // ) // ); // } + // // make sure the orgId is unique const orgExists = await db @@ -174,8 +171,46 @@ export async function createOrg( } } + if (build == "saas") { + if (!billingOrgIdForNewOrg) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Billing org not found for user. Cannot create new organization." + ) + ); + } + + const usage = await usageService.getUsage(billingOrgIdForNewOrg, FeatureId.ORGINIZATIONS); + if (!usage) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + "No usage data found for this organization" + ) + ); + } + const rejectOrgs = await usageService.checkLimitSet( + billingOrgIdForNewOrg, + FeatureId.ORGINIZATIONS, + { + ...usage, + instantaneousValue: (usage.instantaneousValue || 0) + 1 + } // We need to add one to know if we are violating the limit + ); + if (rejectOrgs) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Organization limit exceeded. Please upgrade your plan." + ) + ); + } + } + let error = ""; let org: Org | null = null; + let numOrgs: number | null = null; await db.transaction(async (trx) => { const allDomains = await trx @@ -184,14 +219,14 @@ export async function createOrg( .where(eq(domains.configManaged, true)); // Generate SSH CA keys for the org - const ca = generateCA(`${orgId}-ca`); - const encryptionKey = config.getRawConfig().server.secret!; - const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey); + // const ca = generateCA(`${orgId}-ca`); + // const encryptionKey = config.getRawConfig().server.secret!; + // const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey); const saasBillingFields = build === "saas" && req.user && isFirstOrg !== null ? isFirstOrg - ? { isBillingOrg: true as const, billingOrgId: null } + ? { isBillingOrg: true as const, billingOrgId: orgId } // if this is the first org, it becomes the billing org for itself : { isBillingOrg: false as const, billingOrgId: billingOrgIdForNewOrg @@ -206,8 +241,8 @@ export async function createOrg( subnet, utilitySubnet, createdAt: new Date().toISOString(), - sshCaPrivateKey: encryptedCaPrivateKey, - sshCaPublicKey: ca.publicKeyOpenSSH, + // sshCaPrivateKey: encryptedCaPrivateKey, + // sshCaPublicKey: ca.publicKeyOpenSSH, ...saasBillingFields }) .returning(); @@ -310,6 +345,17 @@ export async function createOrg( ); await calculateUserClientsForOrgs(ownerUserId, trx); + + if (billingOrgIdForNewOrg) { + const [numOrgsResult] = await trx + .select({ count: count() }) + .from(orgs) + .where(eq(orgs.billingOrgId, billingOrgIdForNewOrg)); // all the billable orgs including the primary org that is the billing org itself + + numOrgs = numOrgsResult.count; + } else { + numOrgs = 1; // we only have one org if there is no billing org found out + } }); if (!org) { @@ -326,7 +372,7 @@ export async function createOrg( } if (build === "saas" && isFirstOrg === true) { - await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet); + await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); const customerId = await createCustomer(orgId, req.user?.email); if (customerId) { await usageService.updateCount( @@ -338,6 +384,14 @@ export async function createOrg( } } + if (numOrgs) { + usageService.updateCount( + billingOrgIdForNewOrg || orgId, + FeatureId.ORGINIZATIONS, + numOrgs + ); + } + return response(res, { data: org, success: true, diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index 797bf2ae..ea4bc3e8 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -6,7 +6,7 @@ import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; -import { eq, and } from "drizzle-orm"; +import { eq, and, count } from "drizzle-orm"; import { getUniqueSiteName } from "../../db/names"; import { addPeer } from "../gerbil/peers"; import { fromError } from "zod-validation-error"; @@ -288,7 +288,6 @@ export async function createSite( const niceId = await getUniqueSiteName(orgId); let newSite: Site | undefined; - let numSites: Site[] | undefined; await db.transaction(async (trx) => { if (type == "newt") { [newSite] = await trx @@ -443,20 +442,9 @@ export async function createSite( }); } - numSites = await trx - .select() - .from(sites) - .where(eq(sites.orgId, orgId)); + await usageService.add(orgId, FeatureId.SITES, 1, trx); }); - if (numSites) { - await usageService.updateCount( - orgId, - FeatureId.SITES, - numSites.length - ); - } - if (!newSite) { return next( createHttpError( diff --git a/server/routers/site/deleteSite.ts b/server/routers/site/deleteSite.ts index 2ce900fd..cdb9d3ba 100644 --- a/server/routers/site/deleteSite.ts +++ b/server/routers/site/deleteSite.ts @@ -64,7 +64,6 @@ export async function deleteSite( } let deletedNewtId: string | null = null; - let numSites: Site[] | undefined; await db.transaction(async (trx) => { if (site.type == "wireguard") { @@ -101,21 +100,9 @@ export async function deleteSite( } } - await trx.delete(sites).where(eq(sites.siteId, siteId)); - - numSites = await trx - .select() - .from(sites) - .where(eq(sites.orgId, site.orgId)); + await usageService.add(site.orgId, FeatureId.SITES, -1, trx); }); - if (numSites) { - await usageService.updateCount( - site.orgId, - FeatureId.SITES, - numSites.length - ); - } // Send termination message outside of transaction to prevent blocking if (deletedNewtId) { const payload = { diff --git a/server/routers/user/acceptInvite.ts b/server/routers/user/acceptInvite.ts index 74f025ae..99a609a1 100644 --- a/server/routers/user/acceptInvite.ts +++ b/server/routers/user/acceptInvite.ts @@ -1,8 +1,8 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, UserOrg } from "@server/db"; +import { db, orgs, UserOrg } from "@server/db"; import { roles, userInvites, userOrgs, users } from "@server/db"; -import { eq } from "drizzle-orm"; +import { eq, and, inArray, ne } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -125,8 +125,22 @@ export async function acceptInvite( } } + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, existingInvite.orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Organization does not exist. Please contact an admin." + ) + ); + } + let roleId: number; - let totalUsers: UserOrg[] | undefined; // get the role to make sure it exists const existingRole = await db .select() @@ -160,25 +174,45 @@ export async function acceptInvite( await calculateUserClientsForOrgs(existingUser[0].userId, trx); - // Get the total number of users in the org now - totalUsers = await trx - .select() - .from(userOrgs) - .where(eq(userOrgs.orgId, existingInvite.orgId)); + // calculate if the user is in any other of the orgs before we count it as an add to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where( + and( + eq(orgs.billingOrgId, org.billingOrgId), + ne(orgs.orgId, existingInvite.orgId) + ) + ); + + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheUserIsStillIn = await trx + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, existingUser[0].userId), + inArray(userOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) { + await usageService.add( + existingInvite.orgId, + FeatureId.USERS, + 1, + trx + ); + } + } logger.debug( - `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}` + `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}` ); }); - if (totalUsers) { - await usageService.updateCount( - existingInvite.orgId, - FeatureId.USERS, - totalUsers.length - ); - } - return response(res, { data: { accepted: true, orgId: existingInvite.orgId }, success: true, diff --git a/server/routers/user/createOrgUser.ts b/server/routers/user/createOrgUser.ts index d0515e71..f9cab25e 100644 --- a/server/routers/user/createOrgUser.ts +++ b/server/routers/user/createOrgUser.ts @@ -6,8 +6,8 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; -import { db, UserOrg } from "@server/db"; -import { and, eq } from "drizzle-orm"; +import { db, orgs, UserOrg } from "@server/db"; +import { and, eq, inArray, ne } from "drizzle-orm"; import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db"; import { generateId } from "@server/auth/sessions/app"; import { usageService } from "@server/lib/billing/usageService"; @@ -151,6 +151,21 @@ export async function createOrgUser( ); } + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + "Organization not found" + ) + ); + } + const [idpRes] = await db .select() .from(idp) @@ -172,8 +187,6 @@ export async function createOrgUser( ); } - let orgUsers: UserOrg[] | undefined; - await db.transaction(async (trx) => { const [existingUser] = await trx .select() @@ -244,22 +257,37 @@ export async function createOrgUser( .returning(); } - // List all of the users in the org - orgUsers = await trx - .select() - .from(userOrgs) - .where(eq(userOrgs.orgId, orgId)); - await calculateUserClientsForOrgs(userId, trx); - }); - if (orgUsers) { - await usageService.updateCount( - orgId, - FeatureId.USERS, - orgUsers.length - ); - } + // calculate if the user is in any other of the orgs before we count it as an add to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where( + and( + eq(orgs.billingOrgId, org.billingOrgId), + ne(orgs.orgId, orgId) + ) + ); + + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheUserIsStillIn = await trx + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, userId), + inArray(userOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) { + await usageService.add(orgId, FeatureId.USERS, 1, trx); + } + } + }); } else { return next( createHttpError(HttpCode.BAD_REQUEST, "User type is required") diff --git a/server/routers/user/removeUserOrg.ts b/server/routers/user/removeUserOrg.ts index 768d5fff..d90d78c0 100644 --- a/server/routers/user/removeUserOrg.ts +++ b/server/routers/user/removeUserOrg.ts @@ -1,8 +1,16 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, resources, sites, UserOrg } from "@server/db"; +import { + db, + orgs, + resources, + siteResources, + sites, + UserOrg, + userSiteResources +} from "@server/db"; import { userOrgs, userResources, users, userSites } from "@server/db"; -import { and, count, eq, exists } from "drizzle-orm"; +import { and, count, eq, exists, inArray } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -50,16 +58,16 @@ export async function removeUserOrg( const { userId, orgId } = parsedParams.data; // get the user first - const user = await db + const [user] = await db .select() .from(userOrgs) .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))); - if (!user || user.length === 0) { + if (!user) { return next(createHttpError(HttpCode.NOT_FOUND, "User not found")); } - if (user[0].isOwner) { + if (user.isOwner) { return next( createHttpError( HttpCode.BAD_REQUEST, @@ -68,7 +76,17 @@ export async function removeUserOrg( ); } - let userCount: UserOrg[] | undefined; + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Organization not found") + ); + } await db.transaction(async (trx) => { await trx @@ -97,6 +115,26 @@ export async function removeUserOrg( ) ); + await db.delete(userSiteResources).where( + and( + eq(userSiteResources.userId, userId), + exists( + db + .select() + .from(siteResources) + .where( + and( + eq( + siteResources.siteResourceId, + userSiteResources.siteResourceId + ), + eq(siteResources.orgId, orgId) + ) + ) + ) + ) + ); + await db.delete(userSites).where( and( eq(userSites.userId, userId), @@ -114,11 +152,6 @@ export async function removeUserOrg( ) ); - userCount = await trx - .select() - .from(userOrgs) - .where(eq(userOrgs.orgId, orgId)); - // if (build === "saas") { // const [rootUser] = await trx // .select() @@ -137,15 +170,31 @@ export async function removeUserOrg( // } await calculateUserClientsForOrgs(userId, trx); - }); - if (userCount) { - await usageService.updateCount( - orgId, - FeatureId.USERS, - userCount.length - ); - } + // calculate if the user is in any other of the orgs before we count it as an remove to the billing org + if (org.billingOrgId) { + const billingOrgs = await trx + .select() + .from(orgs) + .where(eq(orgs.billingOrgId, org.billingOrgId)); + + const billingOrgIds = billingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheUserIsStillIn = await trx + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, userId), + inArray(userOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) { + await usageService.add(orgId, FeatureId.USERS, -1, trx); + } + } + }); return response(res, { data: null, diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index b108d461..5f608e55 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -110,37 +110,42 @@ const planOptions: PlanOption[] = [ // Tier limits mapping derived from limit sets const tierLimits: Record< Tier | "basic", - { users: number; sites: number; domains: number; remoteNodes: number } + { users: number; sites: number; domains: number; remoteNodes: number; organizations: number } > = { basic: { users: freeLimitSet[FeatureId.USERS]?.value ?? 0, sites: freeLimitSet[FeatureId.SITES]?.value ?? 0, domains: freeLimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: freeLimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, tier1: { users: tier1LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0, domains: tier1LimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: tier1LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, tier2: { users: tier2LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier2LimitSet[FeatureId.SITES]?.value ?? 0, domains: tier2LimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: tier2LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, tier3: { users: tier3LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier3LimitSet[FeatureId.SITES]?.value ?? 0, domains: tier3LimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: tier3LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, enterprise: { users: 0, // Custom for enterprise sites: 0, // Custom for enterprise domains: 0, // Custom for enterprise - remoteNodes: 0 // Custom for enterprise + remoteNodes: 0, // Custom for enterprise + organizations: 0 // Custom for enterprise } }; @@ -179,6 +184,7 @@ export default function BillingPage() { const SITES = "sites"; const DOMAINS = "domains"; const REMOTE_EXIT_NODES = "remoteExitNodes"; + const ORGINIZATIONS = "organizations"; // Confirmation dialog state const [showConfirmDialog, setShowConfirmDialog] = useState(false); @@ -619,6 +625,16 @@ export default function BillingPage() { }); } + // Check organizations + const organizationsUsage = getUsageValue(ORGINIZATIONS); + if (limits.organizations > 0 && organizationsUsage > limits.organizations) { + violations.push({ + feature: "Organizations", + currentUsage: organizationsUsage, + newLimit: limits.organizations + }); + } + return violations; }; @@ -855,6 +871,41 @@ export default function BillingPage() { )} + + + {t("billingOrganizations") || + "Organizations"} + + + {isOverLimit(ORGINIZATIONS) ? ( + + + + + {getLimitValue(ORGINIZATIONS) ?? + t("billingUnlimited") ?? + "∞"}{" "} + {getLimitValue(ORGINIZATIONS) !== + null && "organizations"} + + + +

{t("billingUsageExceedsLimit", { current: getUsageValue(ORGINIZATIONS), limit: getLimitValue(ORGINIZATIONS) ?? 0 }) || `Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}

+
+
+ ) : ( + <> + {getLimitValue(ORGINIZATIONS) ?? + t("billingUnlimited") ?? + "∞"}{" "} + {getLimitValue(ORGINIZATIONS) !== + null && "organizations"} + + )} +
+
{t("billingRemoteNodes") ||