diff --git a/messages/en-US.json b/messages/en-US.json index 3e825711..44d980c5 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -1031,6 +1031,7 @@ "pangolinSetup": "Setup - Pangolin", "orgNameRequired": "Organization name is required", "orgIdRequired": "Organization ID is required", + "orgIdMaxLength": "Organization ID must be at most 32 characters", "orgErrorCreate": "An error occurred while creating org", "pageNotFound": "Page Not Found", "pageNotFoundDescription": "Oops! The page you're looking for doesn't exist.", @@ -1266,6 +1267,7 @@ "sidebarLogAndAnalytics": "Log & Analytics", "sidebarBluePrints": "Blueprints", "sidebarOrganization": "Organization", + "sidebarBillingAndLicenses": "Billing & Licenses", "sidebarLogsAnalytics": "Analytics", "blueprints": "Blueprints", "blueprintsDescription": "Apply declarative configurations and view previous runs", @@ -1427,6 +1429,7 @@ "billingSites": "Sites", "billingUsers": "Users", "billingDomains": "Domains", + "billingOrganizations": "Orgs", "billingRemoteExitNodes": "Remote Nodes", "billingNoLimitConfigured": "No limit configured", "billingEstimatedPeriod": "Estimated Billing Period", @@ -1469,6 +1472,7 @@ "failed": "Failed", "createNewOrgDescription": "Create a new organization", "organization": "Organization", + "primary": "Primary", "port": "Port", "securityKeyManage": "Manage Security Keys", "securityKeyDescription": "Add or remove security keys for passwordless authentication", diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index ca46e207..7c252b8b 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -55,7 +55,9 @@ export const orgs = pgTable("orgs", { .notNull() .default(0), sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format) - sshCaPublicKey: text("sshCaPublicKey") // SSH CA public key (OpenSSH format) + sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format) + isBillingOrg: boolean("isBillingOrg"), + billingOrgId: varchar("billingOrgId") }); export const orgDomains = pgTable("orgDomains", { diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index ce08dea1..04d4338a 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -47,7 +47,9 @@ export const orgs = sqliteTable("orgs", { .notNull() .default(0), sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format) - sshCaPublicKey: text("sshCaPublicKey") // SSH CA public key (OpenSSH format) + sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format) + isBillingOrg: integer("isBillingOrg", { mode: "boolean" }), + billingOrgId: text("billingOrgId") }); export const userDomains = sqliteTable("userDomains", { diff --git a/server/lib/billing/features.ts b/server/lib/billing/features.ts index 3fec53b4..6063b470 100644 --- a/server/lib/billing/features.ts +++ b/server/lib/billing/features.ts @@ -4,6 +4,7 @@ export enum FeatureId { EGRESS_DATA_MB = "egressDataMb", DOMAINS = "domains", REMOTE_EXIT_NODES = "remoteExitNodes", + ORGINIZATIONS = "organizations", TIER1 = "tier1" } @@ -19,6 +20,8 @@ export async function getFeatureDisplayName(featureId: FeatureId): Promise; -export const sandboxLimitSet: LimitSet = { - [FeatureId.USERS]: { value: 1, description: "Sandbox limit" }, - [FeatureId.SITES]: { value: 1, description: "Sandbox limit" }, - [FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" }, - [FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" }, -}; - export const freeLimitSet: LimitSet = { [FeatureId.SITES]: { value: 5, description: "Basic limit" }, [FeatureId.USERS]: { value: 5, description: "Basic limit" }, [FeatureId.DOMAINS]: { value: 5, description: "Basic limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" }, + [FeatureId.ORGINIZATIONS]: { value: 1, description: "Basic limit" }, }; export const tier1LimitSet: LimitSet = { @@ -26,6 +20,7 @@ export const tier1LimitSet: LimitSet = { [FeatureId.SITES]: { value: 10, description: "Home limit" }, [FeatureId.DOMAINS]: { value: 10, description: "Home limit" }, [FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" }, + [FeatureId.ORGINIZATIONS]: { value: 1, description: "Home limit" }, }; export const tier2LimitSet: LimitSet = { @@ -45,6 +40,10 @@ export const tier2LimitSet: LimitSet = { value: 3, description: "Team limit" }, + [FeatureId.ORGINIZATIONS]: { + value: 1, + description: "Team limit" + } }; export const tier3LimitSet: LimitSet = { @@ -64,4 +63,8 @@ export const tier3LimitSet: LimitSet = { value: 20, description: "Business limit" }, + [FeatureId.ORGINIZATIONS]: { + value: 5, + description: "Business limit" + }, }; diff --git a/server/lib/billing/usageService.ts b/server/lib/billing/usageService.ts index c4ae2925..a7786c76 100644 --- a/server/lib/billing/usageService.ts +++ b/server/lib/billing/usageService.ts @@ -1,34 +1,19 @@ import { eq, sql, and } from "drizzle-orm"; -import { v4 as uuidv4 } from "uuid"; -import { PutObjectCommand } from "@aws-sdk/client-s3"; import { db, usage, customers, - sites, - newts, limits, Usage, Limit, - Transaction + Transaction, + orgs } from "@server/db"; import { FeatureId, getFeatureMeterId } from "./features"; import logger from "@server/logger"; -import { sendToClient } from "#dynamic/routers/ws"; import { build } from "@server/build"; -import { s3Client } from "@server/lib/s3"; import cache from "@server/lib/cache"; -interface StripeEvent { - identifier?: string; - timestamp: number; - event_name: string; - payload: { - value: number; - stripe_customer_id: string; - }; -} - export function noop() { if (build !== "saas") { return true; @@ -37,41 +22,11 @@ export function noop() { } export class UsageService { - private bucketName: string | undefined; - private events: StripeEvent[] = []; - private lastUploadTime: number = Date.now(); - private isUploading: boolean = false; constructor() { if (noop()) { return; } - - // this.bucketName = process.env.S3_BUCKET || undefined; - - // // Periodically check and upload events - // setInterval(() => { - // this.checkAndUploadEvents().catch((err) => { - // logger.error("Error in periodic event upload:", err); - // }); - // }, 30000); // every 30 seconds - - // // Handle graceful shutdown on SIGTERM - // process.on("SIGTERM", async () => { - // logger.info( - // "SIGTERM received, uploading events before shutdown..." - // ); - // await this.forceUpload(); - // logger.info("Events uploaded, proceeding with shutdown"); - // }); - - // // Handle SIGINT as well (Ctrl+C) - // process.on("SIGINT", async () => { - // logger.info("SIGINT received, uploading events before shutdown..."); - // await this.forceUpload(); - // logger.info("Events uploaded, proceeding with shutdown"); - // process.exit(0); - // }); } /** @@ -91,6 +46,8 @@ export class UsageService { return null; } + let orgIdToUse = await this.getBillingOrg(orgId, transaction); + // Truncate value to 11 decimal places value = this.truncateValue(value); @@ -100,20 +57,10 @@ export class UsageService { while (attempt <= maxRetries) { try { - // Get subscription data for this org (with caching) - const customerId = await this.getCustomerId(orgId, featureId); - - if (!customerId) { - logger.warn( - `No subscription data found for org ${orgId} and feature ${featureId}` - ); - return null; - } - let usage; if (transaction) { usage = await this.internalAddUsage( - orgId, + orgIdToUse, featureId, value, transaction @@ -121,7 +68,7 @@ export class UsageService { } else { await db.transaction(async (trx) => { usage = await this.internalAddUsage( - orgId, + orgIdToUse, featureId, value, trx @@ -129,11 +76,6 @@ export class UsageService { }); } - // Log event for Stripe - // if (privateConfig.getRawPrivateConfig().flags.usage_reporting) { - // await this.logStripeEvent(featureId, value, customerId); - // } - return usage || null; } catch (error: any) { // Check if this is a deadlock error @@ -150,7 +92,7 @@ export class UsageService { const delay = baseDelay + jitter; logger.warn( - `Deadlock detected for ${orgId}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms` + `Deadlock detected for ${orgIdToUse}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms` ); await new Promise((resolve) => setTimeout(resolve, delay)); @@ -158,7 +100,7 @@ export class UsageService { } logger.error( - `Failed to add usage for ${orgId}/${featureId} after ${attempt} attempts:`, + `Failed to add usage for ${orgIdToUse}/${featureId} after ${attempt} attempts:`, error ); break; @@ -169,7 +111,7 @@ export class UsageService { } private async internalAddUsage( - orgId: string, + orgId: string, // here the orgId is the billing org already resolved by getBillingOrg in updateCount featureId: FeatureId, value: number, trx: Transaction @@ -188,17 +130,22 @@ export class UsageService { featureId, orgId, meterId, - latestValue: value, + instantaneousValue: value || 0, + latestValue: value || 0, updatedAt: Math.floor(Date.now() / 1000) }) .onConflictDoUpdate({ target: usage.usageId, set: { - latestValue: sql`${usage.latestValue} + ${value}` + instantaneousValue: sql`COALESCE(${usage.instantaneousValue}, 0) + ${value}` } }) .returning(); + logger.debug( + `Added usage for org ${orgId} feature ${featureId}: +${value}, new instantaneousValue: ${returnUsage.instantaneousValue}` + ); + return returnUsage; } @@ -221,18 +168,10 @@ export class UsageService { if (noop()) { return; } - try { - if (!customerId) { - customerId = - (await this.getCustomerId(orgId, featureId)) || undefined; - if (!customerId) { - logger.warn( - `No subscription data found for org ${orgId} and feature ${featureId}` - ); - return; - } - } + let orgIdToUse = await this.getBillingOrg(orgId); + + try { // Truncate value to 11 decimal places if provided if (value !== undefined && value !== null) { value = this.truncateValue(value); @@ -242,7 +181,7 @@ export class UsageService { await db.transaction(async (trx) => { // Get existing meter record - const usageId = `${orgId}-${featureId}`; + const usageId = `${orgIdToUse}-${featureId}`; // Get current usage record [currentUsage] = await trx .select() @@ -264,7 +203,7 @@ export class UsageService { await trx.insert(usage).values({ usageId, featureId, - orgId, + orgId: orgIdToUse, meterId, instantaneousValue: value || 0, latestValue: value || 0, @@ -278,7 +217,7 @@ export class UsageService { // } } catch (error) { logger.error( - `Failed to update count usage for ${orgId}/${featureId}:`, + `Failed to update count usage for ${orgIdToUse}/${featureId}:`, error ); } @@ -288,7 +227,9 @@ export class UsageService { orgId: string, featureId: FeatureId ): Promise { - const cacheKey = `customer_${orgId}_${featureId}`; + let orgIdToUse = await this.getBillingOrg(orgId); + + const cacheKey = `customer_${orgIdToUse}_${featureId}`; const cached = cache.get(cacheKey); if (cached) { @@ -302,7 +243,7 @@ export class UsageService { customerId: customers.customerId }) .from(customers) - .where(eq(customers.orgId, orgId)) + .where(eq(customers.orgId, orgIdToUse)) .limit(1); if (!customer) { @@ -317,112 +258,13 @@ export class UsageService { return customerId; } catch (error) { logger.error( - `Failed to get subscription data for ${orgId}/${featureId}:`, + `Failed to get subscription data for ${orgIdToUse}/${featureId}:`, error ); return null; } } - private async logStripeEvent( - featureId: FeatureId, - value: number, - customerId: string - ): Promise { - // Truncate value to 11 decimal places before sending to Stripe - const truncatedValue = this.truncateValue(value); - - const event: StripeEvent = { - identifier: uuidv4(), - timestamp: Math.floor(new Date().getTime() / 1000), - event_name: featureId, - payload: { - value: truncatedValue, - stripe_customer_id: customerId - } - }; - - this.addEventToMemory(event); - await this.checkAndUploadEvents(); - } - - private addEventToMemory(event: StripeEvent): void { - if (!this.bucketName) { - logger.warn( - "S3 bucket name is not configured, skipping event storage." - ); - return; - } - this.events.push(event); - } - - private async checkAndUploadEvents(): Promise { - const now = Date.now(); - const timeSinceLastUpload = now - this.lastUploadTime; - - // Check if at least 1 minute has passed since last upload - if (timeSinceLastUpload >= 60000 && this.events.length > 0) { - await this.uploadEventsToS3(); - } - } - - private async uploadEventsToS3(): Promise { - if (!this.bucketName) { - logger.warn( - "S3 bucket name is not configured, skipping S3 upload." - ); - return; - } - - if (this.events.length === 0) { - return; - } - - // Check if already uploading - if (this.isUploading) { - logger.debug("Already uploading events, skipping"); - return; - } - - this.isUploading = true; - - try { - // Take a snapshot of current events and clear the array - const eventsToUpload = [...this.events]; - this.events = []; - this.lastUploadTime = Date.now(); - - const fileName = this.generateEventFileName(); - const fileContent = JSON.stringify(eventsToUpload, null, 2); - - // Upload to S3 - const uploadCommand = new PutObjectCommand({ - Bucket: this.bucketName, - Key: fileName, - Body: fileContent, - ContentType: "application/json" - }); - - await s3Client.send(uploadCommand); - - logger.info( - `Uploaded ${fileName} to S3 with ${eventsToUpload.length} events` - ); - } catch (error) { - logger.error("Failed to upload events to S3:", error); - // Note: Events are lost if upload fails. In a production system, - // you might want to add the events back to the array or implement retry logic - } finally { - this.isUploading = false; - } - } - - private generateEventFileName(): string { - const timestamp = new Date().toISOString().replace(/[:.]/g, "-"); - const uuid = uuidv4().substring(0, 8); - return `events-${timestamp}-${uuid}.json`; - } - public async getUsage( orgId: string, featureId: FeatureId, @@ -432,7 +274,9 @@ export class UsageService { return null; } - const usageId = `${orgId}-${featureId}`; + let orgIdToUse = await this.getBillingOrg(orgId, trx); + + const usageId = `${orgIdToUse}-${featureId}`; try { const [result] = await trx @@ -444,7 +288,7 @@ export class UsageService { if (!result) { // Lets create one if it doesn't exist using upsert to handle race conditions logger.info( - `Creating new usage record for ${orgId}/${featureId}` + `Creating new usage record for ${orgIdToUse}/${featureId}` ); const meterId = getFeatureMeterId(featureId); @@ -454,7 +298,7 @@ export class UsageService { .values({ usageId, featureId, - orgId, + orgId: orgIdToUse, meterId, latestValue: 0, updatedAt: Math.floor(Date.now() / 1000) @@ -476,7 +320,7 @@ export class UsageService { } catch (insertError) { // Fallback: try to fetch existing record in case of any insert issues logger.warn( - `Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`, + `Insert failed for ${orgIdToUse}/${featureId}, attempting to fetch existing record:`, insertError ); const [existingUsage] = await trx @@ -491,19 +335,41 @@ export class UsageService { return result; } catch (error) { logger.error( - `Failed to get usage for ${orgId}/${featureId}:`, + `Failed to get usage for ${orgIdToUse}/${featureId}:`, error ); throw error; } } - public async forceUpload(): Promise { - if (this.events.length > 0) { - // Force upload regardless of time - this.lastUploadTime = 0; // Reset to force upload - await this.uploadEventsToS3(); + public async getBillingOrg( + orgId: string, + trx: Transaction | typeof db = db + ): Promise { + let orgIdToUse = orgId; + + // get the org + const [org] = await trx + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + throw new Error(`Organization with ID ${orgId} not found`); } + + if (!org.isBillingOrg) { + if (org.billingOrgId) { + orgIdToUse = org.billingOrgId; + } else { + throw new Error( + `Organization ${orgId} is not a billing org and does not have a billingOrgId set` + ); + } + } + + return orgIdToUse; } public async checkLimitSet( @@ -515,6 +381,9 @@ export class UsageService { if (noop()) { return false; } + + let orgIdToUse = await this.getBillingOrg(orgId, trx); + // This method should check the current usage against the limits set for the organization // and kick out all of the sites on the org let hasExceededLimits = false; @@ -528,7 +397,7 @@ export class UsageService { .from(limits) .where( and( - eq(limits.orgId, orgId), + eq(limits.orgId, orgIdToUse), eq(limits.featureId, featureId) ) ); @@ -537,11 +406,11 @@ export class UsageService { orgLimits = await trx .select() .from(limits) - .where(eq(limits.orgId, orgId)); + .where(eq(limits.orgId, orgIdToUse)); } if (orgLimits.length === 0) { - logger.debug(`No limits set for org ${orgId}`); + logger.debug(`No limits set for org ${orgIdToUse}`); return false; } @@ -552,7 +421,7 @@ export class UsageService { currentUsage = usage; } else { currentUsage = await this.getUsage( - orgId, + orgIdToUse, limit.featureId as FeatureId, trx ); @@ -563,10 +432,10 @@ export class UsageService { currentUsage?.latestValue || 0; logger.debug( - `Current usage for org ${orgId} on feature ${limit.featureId}: ${usageValue}` + `Current usage for org ${orgIdToUse} on feature ${limit.featureId}: ${usageValue}` ); logger.debug( - `Limit for org ${orgId} on feature ${limit.featureId}: ${limit.value}` + `Limit for org ${orgIdToUse} on feature ${limit.featureId}: ${limit.value}` ); if ( currentUsage && @@ -574,7 +443,7 @@ export class UsageService { usageValue > limit.value ) { logger.debug( - `Org ${orgId} has exceeded limit for ${limit.featureId}: ` + + `Org ${orgIdToUse} has exceeded limit for ${limit.featureId}: ` + `${usageValue} > ${limit.value}` ); hasExceededLimits = true; @@ -582,7 +451,7 @@ export class UsageService { } } } catch (error) { - logger.error(`Error checking limits for org ${orgId}:`, error); + logger.error(`Error checking limits for org ${orgIdToUse}:`, error); } return hasExceededLimits; diff --git a/server/lib/createUserAccountOrg.ts b/server/lib/createUserAccountOrg.ts deleted file mode 100644 index 207594a5..00000000 --- a/server/lib/createUserAccountOrg.ts +++ /dev/null @@ -1,206 +0,0 @@ -import { isValidCIDR } from "@server/lib/validators"; -import { getNextAvailableOrgSubnet } from "@server/lib/ip"; -import { - actions, - apiKeyOrg, - apiKeys, - db, - domains, - Org, - orgDomains, - orgs, - roleActions, - roles, - userOrgs -} from "@server/db"; -import { eq } from "drizzle-orm"; -import { defaultRoleAllowedActions } from "@server/routers/role"; -import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing"; -import { createCustomer } from "#dynamic/lib/billing"; -import { usageService } from "@server/lib/billing/usageService"; -import config from "@server/lib/config"; -import { generateCA } from "@server/private/lib/sshCA"; -import { encrypt } from "@server/lib/crypto"; - -export async function createUserAccountOrg( - userId: string, - userEmail: string -): Promise<{ - success: boolean; - org?: { - orgId: string; - name: string; - subnet: string; - }; - error?: string; -}> { - // const subnet = await getNextAvailableOrgSubnet(); - const orgId = "org_" + userId; - const name = `${userEmail}'s Organization`; - - // if (!isValidCIDR(subnet)) { - // return { - // success: false, - // error: "Invalid subnet format. Please provide a valid CIDR notation." - // }; - // } - - // // make sure the subnet is unique - // const subnetExists = await db - // .select() - // .from(orgs) - // .where(eq(orgs.subnet, subnet)) - // .limit(1); - - // if (subnetExists.length > 0) { - // return { success: false, error: `Subnet ${subnet} already exists` }; - // } - - // make sure the orgId is unique - const orgExists = await db - .select() - .from(orgs) - .where(eq(orgs.orgId, orgId)) - .limit(1); - - if (orgExists.length > 0) { - return { - success: false, - error: `Organization with ID ${orgId} already exists` - }; - } - - let error = ""; - let org: Org | null = null; - - await db.transaction(async (trx) => { - const allDomains = await trx - .select() - .from(domains) - .where(eq(domains.configManaged, true)); - - const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group; - - // Generate SSH CA keys for the org - // const ca = generateCA(`${orgId}-ca`); - // const encryptionKey = config.getRawConfig().server.secret!; - // const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey); - - const newOrg = await trx - .insert(orgs) - .values({ - orgId, - name, - // subnet - subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs? - utilitySubnet: utilitySubnet, - createdAt: new Date().toISOString(), - // sshCaPrivateKey: encryptedCaPrivateKey, - // sshCaPublicKey: ca.publicKeyOpenSSH - }) - .returning(); - - if (newOrg.length === 0) { - error = "Failed to create organization"; - trx.rollback(); - return; - } - - org = newOrg[0]; - - // Create admin role within the same transaction - const [insertedRole] = await trx - .insert(roles) - .values({ - orgId: newOrg[0].orgId, - isAdmin: true, - name: "Admin", - description: "Admin role with the most permissions" - }) - .returning({ roleId: roles.roleId }); - - if (!insertedRole || !insertedRole.roleId) { - error = "Failed to create Admin role"; - trx.rollback(); - return; - } - - const roleId = insertedRole.roleId; - - // Get all actions and create role actions - const actionIds = await trx.select().from(actions).execute(); - - if (actionIds.length > 0) { - await trx.insert(roleActions).values( - actionIds.map((action) => ({ - roleId, - actionId: action.actionId, - orgId: newOrg[0].orgId - })) - ); - } - - if (allDomains.length) { - await trx.insert(orgDomains).values( - allDomains.map((domain) => ({ - orgId: newOrg[0].orgId, - domainId: domain.domainId - })) - ); - } - - await trx.insert(userOrgs).values({ - userId, - orgId: newOrg[0].orgId, - roleId: roleId, - isOwner: true - }); - - const memberRole = await trx - .insert(roles) - .values({ - name: "Member", - description: "Members can only view resources", - orgId - }) - .returning(); - - await trx.insert(roleActions).values( - defaultRoleAllowedActions.map((action) => ({ - roleId: memberRole[0].roleId, - actionId: action, - orgId - })) - ); - }); - - await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet); - - if (!org) { - return { success: false, error: "Failed to create org" }; - } - - if (error) { - return { - success: false, - error: `Failed to create org: ${error}` - }; - } - - // make sure we have the stripe customer - const customerId = await createCustomer(orgId, userEmail); - - if (customerId) { - await usageService.updateCount(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org - } - - return { - org: { - orgId, - name, - // subnet - subnet: "100.90.128.0/24" - }, - success: true - }; -} diff --git a/server/lib/deleteOrg.ts b/server/lib/deleteOrg.ts index 7295555d..c0656c2a 100644 --- a/server/lib/deleteOrg.ts +++ b/server/lib/deleteOrg.ts @@ -4,14 +4,18 @@ import { clientSitesAssociationsCache, db, domains, + exitNodeOrgs, + exitNodes, olms, orgDomains, orgs, + remoteExitNodes, resources, - sites + sites, + userOrgs } from "@server/db"; import { newts, newtSessions } from "@server/db"; -import { eq, and, inArray, sql } from "drizzle-orm"; +import { eq, and, inArray, sql, count, countDistinct } from "drizzle-orm"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; @@ -19,6 +23,8 @@ import { sendToClient } from "#dynamic/routers/ws"; import { deletePeer } from "@server/routers/gerbil/peers"; import { OlmErrorCodes } from "@server/routers/olm/error"; import { sendTerminateClient } from "@server/routers/client/terminate"; +import { usageService } from "./billing/usageService"; +import { FeatureId } from "./billing"; export type DeleteOrgByIdResult = { deletedNewtIds: string[]; @@ -60,6 +66,11 @@ export async function deleteOrgById( const deletedNewtIds: string[] = []; const olmsToTerminate: string[] = []; + let domainCount: number | null = null; + let siteCount: number | null = null; + let userCount: number | null = null; + let remoteExitNodeCount: number | null = null; + await db.transaction(async (trx) => { for (const site of orgSites) { if (site.pubKey) { @@ -74,9 +85,7 @@ export async function deleteOrgById( deletedNewtIds.push(deletedNewt.newtId); await trx .delete(newtSessions) - .where( - eq(newtSessions.newtId, deletedNewt.newtId) - ); + .where(eq(newtSessions.newtId, deletedNewt.newtId)); } } } @@ -137,9 +146,74 @@ export async function deleteOrgById( .where(inArray(domains.domainId, domainIdsToDelete)); } await trx.delete(resources).where(eq(resources.orgId, orgId)); + + await usageService.add(orgId, FeatureId.ORGINIZATIONS, -1, trx); // here we are decreasing the org count BEFORE deleting the org because we need to still be able to get the org to get the billing org inside of here + await trx.delete(orgs).where(eq(orgs.orgId, orgId)); + + if (org.billingOrgId) { + const billingOrgs = await trx + .select() + .from(orgs) + .where(eq(orgs.billingOrgId, org.billingOrgId)); + + if (billingOrgs.length > 0) { + const billingOrgIds = billingOrgs.map((org) => org.orgId); + + const [domainCountRes] = await trx + .select({ count: count() }) + .from(orgDomains) + .where(inArray(orgDomains.orgId, billingOrgIds)); + + domainCount = domainCountRes.count; + + const [siteCountRes] = await trx + .select({ count: count() }) + .from(sites) + .where(inArray(sites.orgId, billingOrgIds)); + + siteCount = siteCountRes.count; + + const [userCountRes] = await trx + .select({ count: countDistinct(userOrgs.userId) }) + .from(userOrgs) + .where(inArray(userOrgs.orgId, billingOrgIds)); + + userCount = userCountRes.count; + + const [remoteExitNodeCountRes] = await trx + .select({ count: countDistinct(exitNodeOrgs.exitNodeId) }) + .from(exitNodeOrgs) + .where(inArray(exitNodeOrgs.orgId, billingOrgIds)); + + remoteExitNodeCount = remoteExitNodeCountRes.count; + } + } }); + if (org.billingOrgId) { + usageService.updateCount( + org.billingOrgId, + FeatureId.DOMAINS, + domainCount ?? 0 + ); + usageService.updateCount( + org.billingOrgId, + FeatureId.SITES, + siteCount ?? 0 + ); + usageService.updateCount( + org.billingOrgId, + FeatureId.USERS, + userCount ?? 0 + ); + usageService.updateCount( + org.billingOrgId, + FeatureId.REMOTE_EXIT_NODES, + remoteExitNodeCount ?? 0 + ); + } + return { deletedNewtIds, olmsToTerminate }; } @@ -155,15 +229,13 @@ export function sendTerminationMessages(result: DeleteOrgByIdResult): void { ); } for (const olmId of result.olmsToTerminate) { - sendTerminateClient( - 0, - OlmErrorCodes.TERMINATED_REKEYED, - olmId - ).catch((error) => { - logger.error( - "Failed to send termination message to olm:", - error - ); - }); + sendTerminateClient(0, OlmErrorCodes.TERMINATED_REKEYED, olmId).catch( + (error) => { + logger.error( + "Failed to send termination message to olm:", + error + ); + } + ); } } diff --git a/server/lib/userOrg.ts b/server/lib/userOrg.ts new file mode 100644 index 00000000..6ed10039 --- /dev/null +++ b/server/lib/userOrg.ts @@ -0,0 +1,142 @@ +import { + db, + Org, + orgs, + resources, + siteResources, + sites, + Transaction, + UserOrg, + userOrgs, + userResources, + userSiteResources, + userSites +} from "@server/db"; +import { eq, and, inArray, ne, exists } from "drizzle-orm"; +import { usageService } from "@server/lib/billing/usageService"; +import { FeatureId } from "@server/lib/billing"; + +export async function assignUserToOrg( + org: Org, + values: typeof userOrgs.$inferInsert, + trx: Transaction | typeof db = db +) { + const [userOrg] = await trx.insert(userOrgs).values(values).returning(); + + // calculate if the user is in any other of the orgs before we count it as an add to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where( + and( + eq(orgs.billingOrgId, org.billingOrgId), + ne(orgs.orgId, org.orgId) + ) + ); + + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheUserIsStillIn = await trx + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, userOrg.userId), + inArray(userOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) { + await usageService.add(org.orgId, FeatureId.USERS, 1, trx); + } + } +} + +export async function removeUserFromOrg( + org: Org, + userId: string, + trx: Transaction | typeof db = db +) { + await trx + .delete(userOrgs) + .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, org.orgId))); + + await trx.delete(userResources).where( + and( + eq(userResources.userId, userId), + exists( + trx + .select() + .from(resources) + .where( + and( + eq(resources.resourceId, userResources.resourceId), + eq(resources.orgId, org.orgId) + ) + ) + ) + ) + ); + + await trx.delete(userSiteResources).where( + and( + eq(userSiteResources.userId, userId), + exists( + trx + .select() + .from(siteResources) + .where( + and( + eq( + siteResources.siteResourceId, + userSiteResources.siteResourceId + ), + eq(siteResources.orgId, org.orgId) + ) + ) + ) + ) + ); + + await trx.delete(userSites).where( + and( + eq(userSites.userId, userId), + exists( + db + .select() + .from(sites) + .where( + and( + eq(sites.siteId, userSites.siteId), + eq(sites.orgId, org.orgId) + ) + ) + ) + ) + ); + + // calculate if the user is in any other of the orgs before we count it as an remove to the billing org + if (org.billingOrgId) { + const billingOrgs = await trx + .select() + .from(orgs) + .where(eq(orgs.billingOrgId, org.billingOrgId)); + + const billingOrgIds = billingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheUserIsStillIn = await trx + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, userId), + inArray(userOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) { + await usageService.add(org.orgId, FeatureId.USERS, -1, trx); + } + } +} diff --git a/server/private/lib/billing/getOrgTierData.ts b/server/private/lib/billing/getOrgTierData.ts index 24d658c0..d87f2c38 100644 --- a/server/private/lib/billing/getOrgTierData.ts +++ b/server/private/lib/billing/getOrgTierData.ts @@ -12,7 +12,8 @@ */ import { build } from "@server/build"; -import { db, customers, subscriptions } from "@server/db"; +import { db, customers, subscriptions, orgs } from "@server/db"; +import logger from "@server/logger"; import { Tier } from "@server/types/Tiers"; import { eq, and, ne } from "drizzle-orm"; @@ -27,37 +28,60 @@ export async function getOrgTierData( } try { + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return { tier, active }; + } + + let orgIdToUse = org.orgId; + if (!org.isBillingOrg) { + if (!org.billingOrgId) { + logger.warn( + `Org ${orgId} is not a billing org and does not have a billingOrgId` + ); + return { tier, active }; + } + orgIdToUse = org.billingOrgId; + } + // Get customer for org const [customer] = await db .select() .from(customers) - .where(eq(customers.orgId, orgId)) + .where(eq(customers.orgId, orgIdToUse)) .limit(1); - if (customer) { - // Query for active subscriptions that are not license type - const [subscription] = await db - .select() - .from(subscriptions) - .where( - and( - eq(subscriptions.customerId, customer.customerId), - eq(subscriptions.status, "active"), - ne(subscriptions.type, "license") - ) - ) - .limit(1); + if (!customer) { + return { tier, active }; + } - if (subscription) { - // Validate that subscription.type is one of the expected tier values - if ( - subscription.type === "tier1" || - subscription.type === "tier2" || - subscription.type === "tier3" - ) { - tier = subscription.type; - active = true; - } + // Query for active subscriptions that are not license type + const [subscription] = await db + .select() + .from(subscriptions) + .where( + and( + eq(subscriptions.customerId, customer.customerId), + eq(subscriptions.status, "active"), + ne(subscriptions.type, "license") + ) + ) + .limit(1); + + if (subscription) { + // Validate that subscription.type is one of the expected tier values + if ( + subscription.type === "tier1" || + subscription.type === "tier2" || + subscription.type === "tier3" + ) { + tier = subscription.type; + active = true; } } } catch (error) { diff --git a/server/private/routers/billing/featureLifecycle.ts b/server/private/routers/billing/featureLifecycle.ts index 35345444..3e4b8a4a 100644 --- a/server/private/routers/billing/featureLifecycle.ts +++ b/server/private/routers/billing/featureLifecycle.ts @@ -15,7 +15,18 @@ import { SubscriptionType } from "./hooks/getSubType"; import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; import { Tier } from "@server/types/Tiers"; import logger from "@server/logger"; -import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db"; +import { + db, + idp, + idpOrg, + loginPage, + loginPageBranding, + loginPageBrandingOrg, + loginPageOrg, + orgs, + resources, + roles +} from "@server/db"; import { eq } from "drizzle-orm"; /** @@ -59,10 +70,7 @@ async function capRetentionDays( } // Get current org settings - const [org] = await db - .select() - .from(orgs) - .where(eq(orgs.orgId, orgId)); + const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); if (!org) { logger.warn(`Org ${orgId} not found when capping retention days`); @@ -110,18 +118,13 @@ async function capRetentionDays( // Apply updates if needed if (needsUpdate) { - await db - .update(orgs) - .set(updates) - .where(eq(orgs.orgId, orgId)); + await db.update(orgs).set(updates).where(eq(orgs.orgId, orgId)); logger.info( `Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days` ); } else { - logger.debug( - `No retention day capping needed for org ${orgId}` - ); + logger.debug(`No retention day capping needed for org ${orgId}`); } } @@ -134,6 +137,35 @@ export async function handleTierChange( `Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}` ); + // Get all orgs that have this orgId as their billingOrgId + const associatedOrgs = await db + .select() + .from(orgs) + .where(eq(orgs.billingOrgId, orgId)); + + logger.info( + `Found ${associatedOrgs.length} org(s) associated with billing org ${orgId}` + ); + + // Loop over all associated orgs and apply tier changes + for (const org of associatedOrgs) { + await handleTierChangeForOrg(org.orgId, newTier, previousTier); + } + + logger.info( + `Completed tier change handling for all orgs associated with billing org ${orgId}` + ); +} + +async function handleTierChangeForOrg( + orgId: string, + newTier: SubscriptionType | null, + previousTier?: SubscriptionType | null +): Promise { + logger.info( + `Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}` + ); + // License subscriptions are handled separately and don't use the tier matrix if (newTier === "license") { logger.debug( @@ -314,9 +346,7 @@ async function disableLoginPageDomain(orgId: string): Promise { ); if (existingLoginPage) { - await db - .delete(loginPageOrg) - .where(eq(loginPageOrg.orgId, orgId)); + await db.delete(loginPageOrg).where(eq(loginPageOrg.orgId, orgId)); await db .delete(loginPage) diff --git a/server/private/routers/billing/getOrgSubscriptions.ts b/server/private/routers/billing/getOrgSubscriptions.ts index d2ee8c5b..718c98f4 100644 --- a/server/private/routers/billing/getOrgSubscriptions.ts +++ b/server/private/routers/billing/getOrgSubscriptions.ts @@ -112,11 +112,13 @@ export async function getOrgSubscriptionsData( throw new Error(`Not found`); } + const billingOrgId = org[0].billingOrgId || org[0].orgId; + // Get customer for org const customer = await db .select() .from(customers) - .where(eq(customers.orgId, orgId)) + .where(eq(customers.orgId, billingOrgId)) .limit(1); const subscriptionsWithItems: Array<{ diff --git a/server/private/routers/billing/getOrgUsage.ts b/server/private/routers/billing/getOrgUsage.ts index cf4e7585..4c9f22f3 100644 --- a/server/private/routers/billing/getOrgUsage.ts +++ b/server/private/routers/billing/getOrgUsage.ts @@ -85,10 +85,14 @@ export async function getOrgUsage( orgId, FeatureId.REMOTE_EXIT_NODES ); - const egressData = await usageService.getUsage( + const organizations = await usageService.getUsage( orgId, - FeatureId.EGRESS_DATA_MB + FeatureId.ORGINIZATIONS ); + // const egressData = await usageService.getUsage( + // orgId, + // FeatureId.EGRESS_DATA_MB + // ); if (sites) { usageData.push(sites); @@ -96,15 +100,18 @@ export async function getOrgUsage( if (users) { usageData.push(users); } - if (egressData) { - usageData.push(egressData); - } + // if (egressData) { + // usageData.push(egressData); + // } if (domains) { usageData.push(domains); } if (remoteExitNodes) { usageData.push(remoteExitNodes); } + if (organizations) { + usageData.push(organizations); + } const orgLimits = await db .select() diff --git a/server/private/routers/remoteExitNode/createRemoteExitNode.ts b/server/private/routers/remoteExitNode/createRemoteExitNode.ts index 14541736..6d5b5ea6 100644 --- a/server/private/routers/remoteExitNode/createRemoteExitNode.ts +++ b/server/private/routers/remoteExitNode/createRemoteExitNode.ts @@ -12,7 +12,14 @@ */ import { NextFunction, Request, Response } from "express"; -import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db"; +import { + db, + exitNodes, + exitNodeOrgs, + ExitNode, + ExitNodeOrg, + orgs +} from "@server/db"; import HttpCode from "@server/types/HttpCode"; import { z } from "zod"; import { remoteExitNodes } from "@server/db"; @@ -25,7 +32,7 @@ import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNo import { fromError } from "zod-validation-error"; import { hashPassword, verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray, ne } from "drizzle-orm"; import { getNextAvailableSubnet } from "@server/lib/exitNodes"; import { usageService } from "@server/lib/billing/usageService"; import { FeatureId } from "@server/lib/billing"; @@ -169,7 +176,17 @@ export async function createRemoteExitNode( ); } - let numExitNodeOrgs: ExitNodeOrg[] | undefined; + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Organization not found") + ); + } await db.transaction(async (trx) => { if (!existingExitNode) { @@ -217,19 +234,43 @@ export async function createRemoteExitNode( }); } - numExitNodeOrgs = await trx - .select() - .from(exitNodeOrgs) - .where(eq(exitNodeOrgs.orgId, orgId)); - }); + // calculate if the node is in any other of the orgs before we count it as an add to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where( + and( + eq(orgs.billingOrgId, org.billingOrgId), + ne(orgs.orgId, orgId) + ) + ); - if (numExitNodeOrgs) { - await usageService.updateCount( - orgId, - FeatureId.REMOTE_EXIT_NODES, - numExitNodeOrgs.length - ); - } + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheNodeIsStillIn = await trx + .select() + .from(exitNodeOrgs) + .where( + and( + eq( + exitNodeOrgs.exitNodeId, + existingExitNode.exitNodeId + ), + inArray(exitNodeOrgs.orgId, billingOrgIds) + ) + ); + + if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) { + await usageService.add( + orgId, + FeatureId.REMOTE_EXIT_NODES, + 1, + trx + ); + } + } + }); const token = generateSessionToken(); await createRemoteExitNodeSession(token, remoteExitNodeId); diff --git a/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts b/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts index 8337f05d..6ff6841c 100644 --- a/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts +++ b/server/private/routers/remoteExitNode/deleteRemoteExitNode.ts @@ -13,9 +13,9 @@ import { NextFunction, Request, Response } from "express"; import { z } from "zod"; -import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db"; +import { db, ExitNodeOrg, exitNodeOrgs, exitNodes, orgs } from "@server/db"; import { remoteExitNodes } from "@server/db"; -import { and, count, eq } from "drizzle-orm"; +import { and, count, eq, inArray } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -50,7 +50,8 @@ export async function deleteRemoteExitNode( const [remoteExitNode] = await db .select() .from(remoteExitNodes) - .where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)); + .where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)) + .limit(1); if (!remoteExitNode) { return next( @@ -70,7 +71,17 @@ export async function deleteRemoteExitNode( ); } - let numExitNodeOrgs: ExitNodeOrg[] | undefined; + const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); + + if (!org) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Org with ID ${orgId} not found` + ) + ); + } + await db.transaction(async (trx) => { await trx .delete(exitNodeOrgs) @@ -81,38 +92,39 @@ export async function deleteRemoteExitNode( ) ); - const [remainingExitNodeOrgs] = await trx - .select({ count: count() }) - .from(exitNodeOrgs) - .where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!)); + // calculate if the user is in any other of the orgs before we count it as an remove to the billing org + if (org.billingOrgId) { + const otherBillingOrgs = await trx + .select() + .from(orgs) + .where(eq(orgs.billingOrgId, org.billingOrgId)); - if (remainingExitNodeOrgs.count === 0) { - await trx - .delete(remoteExitNodes) + const billingOrgIds = otherBillingOrgs.map((o) => o.orgId); + + const orgsInBillingDomainThatTheNodeIsStillIn = await trx + .select() + .from(exitNodeOrgs) .where( - eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId) + and( + eq( + exitNodeOrgs.exitNodeId, + remoteExitNode.exitNodeId! + ), + inArray(exitNodeOrgs.orgId, billingOrgIds) + ) ); - await trx - .delete(exitNodes) - .where( - eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!) + + if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) { + await usageService.add( + orgId, + FeatureId.REMOTE_EXIT_NODES, + -1, + trx ); + } } - - numExitNodeOrgs = await trx - .select() - .from(exitNodeOrgs) - .where(eq(exitNodeOrgs.orgId, orgId)); }); - if (numExitNodeOrgs) { - await usageService.updateCount( - orgId, - FeatureId.REMOTE_EXIT_NODES, - numExitNodeOrgs.length - ); - } - return response(res, { data: null, success: true, diff --git a/server/routers/auth/deleteMyAccount.ts b/server/routers/auth/deleteMyAccount.ts index 2c37cd09..8df11243 100644 --- a/server/routers/auth/deleteMyAccount.ts +++ b/server/routers/auth/deleteMyAccount.ts @@ -15,11 +15,10 @@ import { import { verifyPassword } from "@server/auth/password"; import { verifyTotpCode } from "@server/auth/totp"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; -import { - deleteOrgById, - sendTerminationMessages -} from "@server/lib/deleteOrg"; +import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg"; import { UserType } from "@server/types/UserTypes"; +import { build } from "@server/build"; +import { getOrgTierData } from "#dynamic/lib/billing"; const deleteMyAccountBody = z.strictObject({ password: z.string().optional(), @@ -40,11 +39,6 @@ export type DeleteMyAccountSuccessResponse = { success: true; }; -/** - * Self-service account deletion (saas only). Returns preview when no password; - * requires password and optional 2FA code to perform deletion. Uses shared - * deleteOrgById for each owned org (delete-my-account may delete multiple orgs). - */ export async function deleteMyAccount( req: Request, res: Response, @@ -91,18 +85,35 @@ export async function deleteMyAccount( const ownedOrgsRows = await db .select({ - orgId: userOrgs.orgId + orgId: userOrgs.orgId, + isOwner: userOrgs.isOwner, + isBillingOrg: orgs.isBillingOrg }) .from(userOrgs) + .innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId)) .where( - and( - eq(userOrgs.userId, userId), - eq(userOrgs.isOwner, true) - ) + and(eq(userOrgs.userId, userId), eq(userOrgs.isOwner, true)) ); const orgIds = ownedOrgsRows.map((r) => r.orgId); + if (build === "saas" && orgIds.length > 0) { + const primaryOrgId = ownedOrgsRows.find( + (r) => r.isBillingOrg && r.isOwner + )?.orgId; + if (primaryOrgId) { + const { tier, active } = await getOrgTierData(primaryOrgId); + if (active && tier) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "You must cancel your subscription before deleting your account" + ) + ); + } + } + } + if (!password) { const orgsWithNames = orgIds.length > 0 @@ -219,10 +230,7 @@ export async function deleteMyAccount( } catch (error) { logger.error(error); return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "An error occurred" - ) + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") ); } } diff --git a/server/routers/auth/signup.ts b/server/routers/auth/signup.ts index 2605a026..cf8e4141 100644 --- a/server/routers/auth/signup.ts +++ b/server/routers/auth/signup.ts @@ -1,7 +1,7 @@ import { NextFunction, Request, Response } from "express"; import { db, users } from "@server/db"; import HttpCode from "@server/types/HttpCode"; -import { z } from "zod"; +import { email, z } from "zod"; import { fromError } from "zod-validation-error"; import createHttpError from "http-errors"; import response from "@server/lib/response"; @@ -21,7 +21,6 @@ import { hashPassword } from "@server/auth/password"; import { checkValidInvite } from "@server/auth/checkValidInvite"; import { passwordSchema } from "@server/auth/passwordSchema"; import { UserType } from "@server/types/UserTypes"; -import { createUserAccountOrg } from "@server/lib/createUserAccountOrg"; import { build } from "@server/build"; import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend"; @@ -31,7 +30,8 @@ export const signupBodySchema = z.object({ inviteToken: z.string().optional(), inviteId: z.string().optional(), termsAcceptedTimestamp: z.string().nullable().optional(), - marketingEmailConsent: z.boolean().optional() + marketingEmailConsent: z.boolean().optional(), + skipVerificationEmail: z.boolean().optional() }); export type SignUpBody = z.infer; @@ -62,7 +62,8 @@ export async function signup( inviteToken, inviteId, termsAcceptedTimestamp, - marketingEmailConsent + marketingEmailConsent, + skipVerificationEmail } = parsedBody.data; const passwordHash = await hashPassword(password); @@ -198,26 +199,6 @@ export async function signup( // orgId: null, // }); - if (build == "saas") { - const { success, error, org } = await createUserAccountOrg( - userId, - email - ); - if (!success) { - if (error) { - return next( - createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error) - ); - } - return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "Failed to create user account and organization" - ) - ); - } - } - const token = generateSessionToken(); const sess = await createSession(token, userId); const isSecure = req.protocol === "https"; @@ -235,7 +216,13 @@ export async function signup( } if (config.getRawConfig().flags?.require_email_verification) { - sendEmailVerificationCode(email, userId); + if (!skipVerificationEmail) { + sendEmailVerificationCode(email, userId); + } else { + logger.debug( + `User ${email} opted out of verification email during signup.` + ); + } return response(res, { data: { @@ -243,7 +230,9 @@ export async function signup( }, success: true, error: false, - message: `User created successfully. We sent an email to ${email} with a verification code.`, + message: skipVerificationEmail + ? "User created successfully. Please verify your email." + : `User created successfully. We sent an email to ${email} with a verification code.`, status: HttpCode.OK }); } diff --git a/server/routers/domain/createOrgDomain.ts b/server/routers/domain/createOrgDomain.ts index 35fb305f..ceb61b25 100644 --- a/server/routers/domain/createOrgDomain.ts +++ b/server/routers/domain/createOrgDomain.ts @@ -148,7 +148,6 @@ export async function createOrgDomain( } } - let numOrgDomains: OrgDomains[] | undefined; let aRecords: CreateDomainResponse["aRecords"]; let cnameRecords: CreateDomainResponse["cnameRecords"]; let txtRecords: CreateDomainResponse["txtRecords"]; @@ -347,20 +346,9 @@ export async function createOrgDomain( await trx.insert(dnsRecords).values(recordsToInsert); } - numOrgDomains = await trx - .select() - .from(orgDomains) - .where(eq(orgDomains.orgId, orgId)); + await usageService.add(orgId, FeatureId.DOMAINS, 1, trx); }); - if (numOrgDomains) { - await usageService.updateCount( - orgId, - FeatureId.DOMAINS, - numOrgDomains.length - ); - } - if (!returned) { return next( createHttpError( diff --git a/server/routers/domain/deleteOrgDomain.ts b/server/routers/domain/deleteOrgDomain.ts index 04829a13..4c347668 100644 --- a/server/routers/domain/deleteOrgDomain.ts +++ b/server/routers/domain/deleteOrgDomain.ts @@ -36,8 +36,6 @@ export async function deleteAccountDomain( } const { domainId, orgId } = parsed.data; - let numOrgDomains: OrgDomains[] | undefined; - await db.transaction(async (trx) => { const [existing] = await trx .select() @@ -79,20 +77,9 @@ export async function deleteAccountDomain( await trx.delete(domains).where(eq(domains.domainId, domainId)); - numOrgDomains = await trx - .select() - .from(orgDomains) - .where(eq(orgDomains.orgId, orgId)); + await usageService.add(orgId, FeatureId.DOMAINS, -1, trx); }); - if (numOrgDomains) { - await usageService.updateCount( - orgId, - FeatureId.DOMAINS, - numOrgDomains.length - ); - } - return response(res, { data: { success: true }, success: true, diff --git a/server/routers/external.ts b/server/routers/external.ts index a9d075a6..45ab58bb 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -65,9 +65,8 @@ authenticated.use(verifySessionUserMiddleware); authenticated.get("/pick-org-defaults", org.pickOrgDefaults); authenticated.get("/org/checkId", org.checkId); -if (build === "oss" || build === "enterprise") { - authenticated.put("/org", getUserOrgs, org.createOrg); -} + +authenticated.put("/org", getUserOrgs, org.createOrg); authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs); authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs); @@ -87,16 +86,14 @@ authenticated.post( org.updateOrg ); -if (build !== "saas") { - authenticated.delete( - "/org/:orgId", - verifyOrgAccess, - verifyUserIsOrgOwner, - verifyUserHasAction(ActionsEnum.deleteOrg), - logActionAudit(ActionsEnum.deleteOrg), - org.deleteOrg - ); -} +authenticated.delete( + "/org/:orgId", + verifyOrgAccess, + verifyUserIsOrgOwner, + verifyUserHasAction(ActionsEnum.deleteOrg), + logActionAudit(ActionsEnum.deleteOrg), + org.deleteOrg +); authenticated.put( "/org/:orgId/site", diff --git a/server/routers/idp/validateOidcCallback.ts b/server/routers/idp/validateOidcCallback.ts index f4065c59..e3462185 100644 --- a/server/routers/idp/validateOidcCallback.ts +++ b/server/routers/idp/validateOidcCallback.ts @@ -36,6 +36,10 @@ import { build } from "@server/build"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; +import { + assignUserToOrg, + removeUserFromOrg +} from "@server/lib/userOrg"; const ensureTrailingSlash = (url: string): string => { return url; @@ -436,6 +440,7 @@ export async function validateOidcCallback( } } + // These are the orgs that the user should be provisioned into based on the IdP mappings and the token claims logger.debug("User org info", { userOrgInfo }); let existingUserId = existingUser?.userId; @@ -454,15 +459,32 @@ export async function validateOidcCallback( ); if (!existingUserOrgs.length) { - // delete all auto -provisioned user orgs - await db - .delete(userOrgs) + // delete all auto-provisioned user orgs + const autoProvisionedUserOrgs = await db + .select() + .from(userOrgs) .where( and( eq(userOrgs.userId, existingUser.userId), eq(userOrgs.autoProvisioned, true) ) ); + const orgIdsToRemove = autoProvisionedUserOrgs.map( + (uo) => uo.orgId + ); + if (orgIdsToRemove.length > 0) { + const orgsToRemove = await db + .select() + .from(orgs) + .where(inArray(orgs.orgId, orgIdsToRemove)); + for (const org of orgsToRemove) { + await removeUserFromOrg( + org, + existingUser.userId, + db + ); + } + } await calculateUserClientsForOrgs(existingUser.userId); @@ -484,7 +506,7 @@ export async function validateOidcCallback( } } - const orgUserCounts: { orgId: string; userCount: number }[] = []; + const orgUserCounts: { orgId: string; userCount: number }[] = []; // sync the user with the orgs and roles await db.transaction(async (trx) => { @@ -538,15 +560,14 @@ export async function validateOidcCallback( ); if (orgsToDelete.length > 0) { - await trx.delete(userOrgs).where( - and( - eq(userOrgs.userId, userId!), - inArray( - userOrgs.orgId, - orgsToDelete.map((org) => org.orgId) - ) - ) - ); + const orgIdsToRemove = orgsToDelete.map((org) => org.orgId); + const fullOrgsToRemove = await trx + .select() + .from(orgs) + .where(inArray(orgs.orgId, orgIdsToRemove)); + for (const org of fullOrgsToRemove) { + await removeUserFromOrg(org, userId!, trx); + } } // Update roles for existing auto-provisioned orgs where the role has changed @@ -587,15 +608,24 @@ export async function validateOidcCallback( ); if (orgsToAdd.length > 0) { - await trx.insert(userOrgs).values( - orgsToAdd.map((org) => ({ - userId: userId!, - orgId: org.orgId, - roleId: org.roleId, - autoProvisioned: true, - dateCreated: new Date().toISOString() - })) - ); + for (const org of orgsToAdd) { + const [fullOrg] = await trx + .select() + .from(orgs) + .where(eq(orgs.orgId, org.orgId)); + if (fullOrg) { + await assignUserToOrg( + fullOrg, + { + orgId: org.orgId, + userId: userId!, + roleId: org.roleId, + autoProvisioned: true, + }, + trx + ); + } + } } // Loop through all the orgs and get the total number of users from the userOrgs table diff --git a/server/routers/org/createOrg.ts b/server/routers/org/createOrg.ts index 22e9314e..59aa86d2 100644 --- a/server/routers/org/createOrg.ts +++ b/server/routers/org/createOrg.ts @@ -1,7 +1,7 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { db } from "@server/db"; -import { eq } from "drizzle-orm"; +import { and, count, eq } from "drizzle-orm"; import { domains, Org, @@ -24,15 +24,24 @@ import { OpenAPITags, registry } from "@server/openApi"; import { isValidCIDR } from "@server/lib/validators"; import { createCustomer } from "#dynamic/lib/billing"; import { usageService } from "@server/lib/billing/usageService"; -import { FeatureId } from "@server/lib/billing"; +import { FeatureId, limitsService, freeLimitSet } from "@server/lib/billing"; import { build } from "@server/build"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { doCidrsOverlap } from "@server/lib/ip"; import { generateCA } from "@server/private/lib/sshCA"; import { encrypt } from "@server/lib/crypto"; +const validOrgIdRegex = /^[a-z0-9_]+(-[a-z0-9_]+)*$/; + const createOrgSchema = z.strictObject({ - orgId: z.string(), + orgId: z + .string() + .min(1, "Organization ID is required") + .max(32, "Organization ID must be at most 32 characters") + .refine((val) => validOrgIdRegex.test(val), { + message: + "Organization ID must contain only lowercase letters, numbers, underscores, and single hyphens (no leading, trailing, or consecutive hyphens)" + }), name: z.string().min(1).max(255), subnet: z // .union([z.cidrv4(), z.cidrv6()]) @@ -110,6 +119,7 @@ export async function createOrg( // ) // ); // } + // // make sure the orgId is unique const orgExists = await db @@ -136,8 +146,71 @@ export async function createOrg( ); } + let isFirstOrg: boolean | null = null; + let billingOrgIdForNewOrg: string | null = null; + if (build === "saas" && req.user) { + const ownedOrgs = await db + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, req.user.userId), + eq(userOrgs.isOwner, true) + ) + ); + if (ownedOrgs.length === 0) { + isFirstOrg = true; + } else { + isFirstOrg = false; + const [billingOrg] = await db + .select({ orgId: orgs.orgId }) + .from(orgs) + .innerJoin(userOrgs, eq(orgs.orgId, userOrgs.orgId)) + .where( + and( + eq(userOrgs.userId, req.user.userId), + eq(userOrgs.isOwner, true), + eq(orgs.isBillingOrg, true) + ) + ) + .limit(1); + if (billingOrg) { + billingOrgIdForNewOrg = billingOrg.orgId; + } + } + } + + if (build == "saas" && billingOrgIdForNewOrg) { + const usage = await usageService.getUsage(billingOrgIdForNewOrg, FeatureId.ORGINIZATIONS); + if (!usage) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + "No usage data found for this organization" + ) + ); + } + const rejectOrgs = await usageService.checkLimitSet( + billingOrgIdForNewOrg, + FeatureId.ORGINIZATIONS, + { + ...usage, + instantaneousValue: (usage.instantaneousValue || 0) + 1 + } // We need to add one to know if we are violating the limit + ); + if (rejectOrgs) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Organization limit exceeded. Please upgrade your plan." + ) + ); + } + } + let error = ""; let org: Org | null = null; + let numOrgs: number | null = null; await db.transaction(async (trx) => { const allDomains = await trx @@ -145,11 +218,21 @@ export async function createOrg( .from(domains) .where(eq(domains.configManaged, true)); - // // Generate SSH CA keys for the org + // Generate SSH CA keys for the org // const ca = generateCA(`${orgId}-ca`); // const encryptionKey = config.getRawConfig().server.secret!; // const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey); + const saasBillingFields = + build === "saas" && req.user && isFirstOrg !== null + ? isFirstOrg + ? { isBillingOrg: true as const, billingOrgId: orgId } // if this is the first org, it becomes the billing org for itself + : { + isBillingOrg: false as const, + billingOrgId: billingOrgIdForNewOrg + } + : {}; + const newOrg = await trx .insert(orgs) .values({ @@ -159,7 +242,8 @@ export async function createOrg( utilitySubnet, createdAt: new Date().toISOString(), // sshCaPrivateKey: encryptedCaPrivateKey, - // sshCaPublicKey: ca.publicKeyOpenSSH + // sshCaPublicKey: ca.publicKeyOpenSSH, + ...saasBillingFields }) .returning(); @@ -261,6 +345,17 @@ export async function createOrg( ); await calculateUserClientsForOrgs(ownerUserId, trx); + + if (billingOrgIdForNewOrg) { + const [numOrgsResult] = await trx + .select({ count: count() }) + .from(orgs) + .where(eq(orgs.billingOrgId, billingOrgIdForNewOrg)); // all the billable orgs including the primary org that is the billing org itself + + numOrgs = numOrgsResult.count; + } else { + numOrgs = 1; // we only have one org if there is no billing org found out + } }); if (!org) { @@ -276,8 +371,8 @@ export async function createOrg( return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)); } - if (build == "saas") { - // make sure we have the stripe customer + if (build === "saas" && isFirstOrg === true) { + await limitsService.applyLimitSetToOrg(orgId, freeLimitSet); const customerId = await createCustomer(orgId, req.user?.email); if (customerId) { await usageService.updateCount( @@ -289,6 +384,14 @@ export async function createOrg( } } + if (numOrgs) { + usageService.updateCount( + billingOrgIdForNewOrg || orgId, + FeatureId.ORGINIZATIONS, + numOrgs + ); + } + return response(res, { data: org, success: true, diff --git a/server/routers/org/deleteOrg.ts b/server/routers/org/deleteOrg.ts index 0e5b87a2..7de02162 100644 --- a/server/routers/org/deleteOrg.ts +++ b/server/routers/org/deleteOrg.ts @@ -7,6 +7,8 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg"; +import { db, userOrgs, orgs } from "@server/db"; +import { eq, and } from "drizzle-orm"; const deleteOrgSchema = z.strictObject({ orgId: z.string() @@ -41,6 +43,48 @@ export async function deleteOrg( ); } const { orgId } = parsedParams.data; + + const [data] = await db + .select() + .from(userOrgs) + .innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId)) + .where( + and( + eq(userOrgs.orgId, orgId), + eq(userOrgs.userId, req.user!.userId) + ) + ); + + const org = data?.orgs; + const userOrg = data?.userOrgs; + + if (!org || !userOrg) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Organization with ID ${orgId} not found` + ) + ); + } + + if (!userOrg.isOwner) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Only organization owners can delete the organization" + ) + ); + } + + if (org.isBillingOrg) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Cannot delete a primary organization" + ) + ); + } + const result = await deleteOrgById(orgId); sendTerminationMessages(result); return response(res, { diff --git a/server/routers/org/listUserOrgs.ts b/server/routers/org/listUserOrgs.ts index 103b1023..301d0203 100644 --- a/server/routers/org/listUserOrgs.ts +++ b/server/routers/org/listUserOrgs.ts @@ -40,7 +40,11 @@ const listOrgsSchema = z.object({ // responses: {} // }); -type ResponseOrg = Org & { isOwner?: boolean; isAdmin?: boolean }; +type ResponseOrg = Org & { + isOwner?: boolean; + isAdmin?: boolean; + isPrimaryOrg?: boolean; +}; export type ListUserOrgsResponse = { orgs: ResponseOrg[]; @@ -132,6 +136,9 @@ export async function listUserOrgs( if (val.roles && val.roles.isAdmin) { res.isAdmin = val.roles.isAdmin; } + if (val.userOrgs?.isOwner && val.orgs?.isBillingOrg) { + res.isPrimaryOrg = val.orgs.isBillingOrg; + } return res; }); diff --git a/server/routers/resource/getUserResources.ts b/server/routers/resource/getUserResources.ts index 3d28da6f..eb5f8a8d 100644 --- a/server/routers/resource/getUserResources.ts +++ b/server/routers/resource/getUserResources.ts @@ -8,7 +8,10 @@ import { userOrgs, resourcePassword, resourcePincode, - resourceWhitelist + resourceWhitelist, + siteResources, + userSiteResources, + roleSiteResources } from "@server/db"; import createHttpError from "http-errors"; import HttpCode from "@server/types/HttpCode"; @@ -57,9 +60,21 @@ export async function getUserResources( .from(roleResources) .where(eq(roleResources.roleId, userRoleId)); - const [directResources, roleResourceResults] = await Promise.all([ + const directSiteResourcesQuery = db + .select({ siteResourceId: userSiteResources.siteResourceId }) + .from(userSiteResources) + .where(eq(userSiteResources.userId, userId)); + + const roleSiteResourcesQuery = db + .select({ siteResourceId: roleSiteResources.siteResourceId }) + .from(roleSiteResources) + .where(eq(roleSiteResources.roleId, userRoleId)); + + const [directResources, roleResourceResults, directSiteResourceResults, roleSiteResourceResults] = await Promise.all([ directResourcesQuery, - roleResourcesQuery + roleResourcesQuery, + directSiteResourcesQuery, + roleSiteResourcesQuery ]); // Combine all accessible resource IDs @@ -68,18 +83,25 @@ export async function getUserResources( ...roleResourceResults.map((r) => r.resourceId) ]; - if (accessibleResourceIds.length === 0) { - return response(res, { - data: { resources: [] }, - success: true, - error: false, - message: "No resources found", - status: HttpCode.OK - }); - } + // Combine all accessible site resource IDs + const accessibleSiteResourceIds = [ + ...directSiteResourceResults.map((r) => r.siteResourceId), + ...roleSiteResourceResults.map((r) => r.siteResourceId) + ]; // Get resource details for accessible resources - const resourcesData = await db + let resourcesData: Array<{ + resourceId: number; + name: string; + fullDomain: string | null; + ssl: boolean; + enabled: boolean; + sso: boolean; + protocol: string; + emailWhitelistEnabled: boolean; + }> = []; + if (accessibleResourceIds.length > 0) { + resourcesData = await db .select({ resourceId: resources.resourceId, name: resources.name, @@ -98,6 +120,40 @@ export async function getUserResources( eq(resources.enabled, true) ) ); + } + + // Get site resource details for accessible site resources + let siteResourcesData: Array<{ + siteResourceId: number; + name: string; + destination: string; + mode: string; + protocol: string | null; + enabled: boolean; + alias: string | null; + aliasAddress: string | null; + }> = []; + if (accessibleSiteResourceIds.length > 0) { + siteResourcesData = await db + .select({ + siteResourceId: siteResources.siteResourceId, + name: siteResources.name, + destination: siteResources.destination, + mode: siteResources.mode, + protocol: siteResources.protocol, + enabled: siteResources.enabled, + alias: siteResources.alias, + aliasAddress: siteResources.aliasAddress + }) + .from(siteResources) + .where( + and( + inArray(siteResources.siteResourceId, accessibleSiteResourceIds), + eq(siteResources.orgId, orgId), + eq(siteResources.enabled, true) + ) + ); + } // Check for password, pincode, and whitelist protection for each resource const resourcesWithAuth = await Promise.all( @@ -161,8 +217,26 @@ export async function getUserResources( }) ); + // Format site resources + const siteResourcesFormatted = siteResourcesData.map((siteResource) => { + return { + siteResourceId: siteResource.siteResourceId, + name: siteResource.name, + destination: siteResource.destination, + mode: siteResource.mode, + protocol: siteResource.protocol, + enabled: siteResource.enabled, + alias: siteResource.alias, + aliasAddress: siteResource.aliasAddress, + type: 'site' as const + }; + }); + return response(res, { - data: { resources: resourcesWithAuth }, + data: { + resources: resourcesWithAuth, + siteResources: siteResourcesFormatted + }, success: true, error: false, message: "User resources retrieved successfully", @@ -190,5 +264,16 @@ export type GetUserResourcesResponse = { protected: boolean; protocol: string; }>; + siteResources: Array<{ + siteResourceId: number; + name: string; + destination: string; + mode: string; + protocol: string | null; + enabled: boolean; + alias: string | null; + aliasAddress: string | null; + type: 'site'; + }>; }; }; diff --git a/server/routers/site/createSite.ts b/server/routers/site/createSite.ts index 797bf2ae..ea4bc3e8 100644 --- a/server/routers/site/createSite.ts +++ b/server/routers/site/createSite.ts @@ -6,7 +6,7 @@ import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; -import { eq, and } from "drizzle-orm"; +import { eq, and, count } from "drizzle-orm"; import { getUniqueSiteName } from "../../db/names"; import { addPeer } from "../gerbil/peers"; import { fromError } from "zod-validation-error"; @@ -288,7 +288,6 @@ export async function createSite( const niceId = await getUniqueSiteName(orgId); let newSite: Site | undefined; - let numSites: Site[] | undefined; await db.transaction(async (trx) => { if (type == "newt") { [newSite] = await trx @@ -443,20 +442,9 @@ export async function createSite( }); } - numSites = await trx - .select() - .from(sites) - .where(eq(sites.orgId, orgId)); + await usageService.add(orgId, FeatureId.SITES, 1, trx); }); - if (numSites) { - await usageService.updateCount( - orgId, - FeatureId.SITES, - numSites.length - ); - } - if (!newSite) { return next( createHttpError( diff --git a/server/routers/site/deleteSite.ts b/server/routers/site/deleteSite.ts index 2ce900fd..58757253 100644 --- a/server/routers/site/deleteSite.ts +++ b/server/routers/site/deleteSite.ts @@ -64,7 +64,6 @@ export async function deleteSite( } let deletedNewtId: string | null = null; - let numSites: Site[] | undefined; await db.transaction(async (trx) => { if (site.type == "wireguard") { @@ -103,19 +102,9 @@ export async function deleteSite( await trx.delete(sites).where(eq(sites.siteId, siteId)); - numSites = await trx - .select() - .from(sites) - .where(eq(sites.orgId, site.orgId)); + await usageService.add(site.orgId, FeatureId.SITES, -1, trx); }); - if (numSites) { - await usageService.updateCount( - site.orgId, - FeatureId.SITES, - numSites.length - ); - } // Send termination message outside of transaction to prevent blocking if (deletedNewtId) { const payload = { diff --git a/server/routers/user/acceptInvite.ts b/server/routers/user/acceptInvite.ts index 74f025ae..388db4a3 100644 --- a/server/routers/user/acceptInvite.ts +++ b/server/routers/user/acceptInvite.ts @@ -1,8 +1,8 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, UserOrg } from "@server/db"; +import { db, orgs, UserOrg } from "@server/db"; import { roles, userInvites, userOrgs, users } from "@server/db"; -import { eq } from "drizzle-orm"; +import { eq, and, inArray, ne } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -14,6 +14,7 @@ import { usageService } from "@server/lib/billing/usageService"; import { FeatureId } from "@server/lib/billing"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { build } from "@server/build"; +import { assignUserToOrg } from "@server/lib/userOrg"; const acceptInviteBodySchema = z.strictObject({ token: z.string(), @@ -125,8 +126,22 @@ export async function acceptInvite( } } + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, existingInvite.orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Organization does not exist. Please contact an admin." + ) + ); + } + let roleId: number; - let totalUsers: UserOrg[] | undefined; // get the role to make sure it exists const existingRole = await db .select() @@ -146,12 +161,15 @@ export async function acceptInvite( } await db.transaction(async (trx) => { - // add the user to the org - await trx.insert(userOrgs).values({ - userId: existingUser[0].userId, - orgId: existingInvite.orgId, - roleId: existingInvite.roleId - }); + await assignUserToOrg( + org, + { + userId: existingUser[0].userId, + orgId: existingInvite.orgId, + roleId: existingInvite.roleId + }, + trx + ); // delete the invite await trx @@ -160,25 +178,11 @@ export async function acceptInvite( await calculateUserClientsForOrgs(existingUser[0].userId, trx); - // Get the total number of users in the org now - totalUsers = await trx - .select() - .from(userOrgs) - .where(eq(userOrgs.orgId, existingInvite.orgId)); - logger.debug( - `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}` + `User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}` ); }); - if (totalUsers) { - await usageService.updateCount( - existingInvite.orgId, - FeatureId.USERS, - totalUsers.length - ); - } - return response(res, { data: { accepted: true, orgId: existingInvite.orgId }, success: true, diff --git a/server/routers/user/createOrgUser.ts b/server/routers/user/createOrgUser.ts index d0515e71..b39ea22e 100644 --- a/server/routers/user/createOrgUser.ts +++ b/server/routers/user/createOrgUser.ts @@ -6,8 +6,8 @@ import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { OpenAPITags, registry } from "@server/openApi"; -import { db, UserOrg } from "@server/db"; -import { and, eq } from "drizzle-orm"; +import { db, orgs, UserOrg } from "@server/db"; +import { and, eq, inArray, ne } from "drizzle-orm"; import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db"; import { generateId } from "@server/auth/sessions/app"; import { usageService } from "@server/lib/billing/usageService"; @@ -16,6 +16,7 @@ import { build } from "@server/build"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; import { isSubscribed } from "#dynamic/lib/isSubscribed"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; +import { assignUserToOrg } from "@server/lib/userOrg"; const paramsSchema = z.strictObject({ orgId: z.string().nonempty() @@ -151,6 +152,21 @@ export async function createOrgUser( ); } + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + "Organization not found" + ) + ); + } + const [idpRes] = await db .select() .from(idp) @@ -172,8 +188,6 @@ export async function createOrgUser( ); } - let orgUsers: UserOrg[] | undefined; - await db.transaction(async (trx) => { const [existingUser] = await trx .select() @@ -207,15 +221,12 @@ export async function createOrgUser( ); } - await trx - .insert(userOrgs) - .values({ - orgId, - userId: existingUser.userId, - roleId: role.roleId, - autoProvisioned: false - }) - .returning(); + await assignUserToOrg(org, { + orgId, + userId: existingUser.userId, + roleId: role.roleId, + autoProvisioned: false + }, trx); } else { userId = generateId(15); @@ -233,33 +244,16 @@ export async function createOrgUser( }) .returning(); - await trx - .insert(userOrgs) - .values({ + await assignUserToOrg(org, { orgId, userId: newUser.userId, roleId: role.roleId, autoProvisioned: false - }) - .returning(); + }, trx); } - // List all of the users in the org - orgUsers = await trx - .select() - .from(userOrgs) - .where(eq(userOrgs.orgId, orgId)); - await calculateUserClientsForOrgs(userId, trx); }); - - if (orgUsers) { - await usageService.updateCount( - orgId, - FeatureId.USERS, - orgUsers.length - ); - } } else { return next( createHttpError(HttpCode.BAD_REQUEST, "User type is required") diff --git a/server/routers/user/removeUserOrg.ts b/server/routers/user/removeUserOrg.ts index 768d5fff..4c321ad3 100644 --- a/server/routers/user/removeUserOrg.ts +++ b/server/routers/user/removeUserOrg.ts @@ -1,8 +1,16 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db, resources, sites, UserOrg } from "@server/db"; +import { + db, + orgs, + resources, + siteResources, + sites, + UserOrg, + userSiteResources +} from "@server/db"; import { userOrgs, userResources, users, userSites } from "@server/db"; -import { and, count, eq, exists } from "drizzle-orm"; +import { and, count, eq, exists, inArray } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -14,6 +22,7 @@ import { FeatureId } from "@server/lib/billing"; import { build } from "@server/build"; import { UserType } from "@server/types/UserTypes"; import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs"; +import { removeUserFromOrg } from "@server/lib/userOrg"; const removeUserSchema = z.strictObject({ userId: z.string(), @@ -50,16 +59,16 @@ export async function removeUserOrg( const { userId, orgId } = parsedParams.data; // get the user first - const user = await db + const [user] = await db .select() .from(userOrgs) .where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))); - if (!user || user.length === 0) { + if (!user) { return next(createHttpError(HttpCode.NOT_FOUND, "User not found")); } - if (user[0].isOwner) { + if (user.isOwner) { return next( createHttpError( HttpCode.BAD_REQUEST, @@ -68,56 +77,20 @@ export async function removeUserOrg( ); } - let userCount: UserOrg[] | undefined; + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)) + .limit(1); + + if (!org) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Organization not found") + ); + } await db.transaction(async (trx) => { - await trx - .delete(userOrgs) - .where( - and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)) - ); - - await db.delete(userResources).where( - and( - eq(userResources.userId, userId), - exists( - db - .select() - .from(resources) - .where( - and( - eq( - resources.resourceId, - userResources.resourceId - ), - eq(resources.orgId, orgId) - ) - ) - ) - ) - ); - - await db.delete(userSites).where( - and( - eq(userSites.userId, userId), - exists( - db - .select() - .from(sites) - .where( - and( - eq(sites.siteId, userSites.siteId), - eq(sites.orgId, orgId) - ) - ) - ) - ) - ); - - userCount = await trx - .select() - .from(userOrgs) - .where(eq(userOrgs.orgId, orgId)); + await removeUserFromOrg(org, userId, trx); // if (build === "saas") { // const [rootUser] = await trx @@ -139,14 +112,6 @@ export async function removeUserOrg( await calculateUserClientsForOrgs(userId, trx); }); - if (userCount) { - await usageService.updateCount( - orgId, - FeatureId.USERS, - userCount.length - ); - } - return response(res, { data: null, success: true, diff --git a/src/app/[orgId]/settings/(private)/billing/layout.tsx b/src/app/[orgId]/settings/(private)/billing/layout.tsx index c4048bcc..69c3da48 100644 --- a/src/app/[orgId]/settings/(private)/billing/layout.tsx +++ b/src/app/[orgId]/settings/(private)/billing/layout.tsx @@ -6,6 +6,7 @@ import { redirect } from "next/navigation"; import { getTranslations } from "next-intl/server"; import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser"; import { getCachedOrg } from "@app/lib/api/getCachedOrg"; +import { build } from "@server/build"; type BillingSettingsProps = { children: React.ReactNode; @@ -17,6 +18,9 @@ export default async function BillingSettingsPage({ params }: BillingSettingsProps) { const { orgId } = await params; + if (build !== "saas") { + redirect(`/${orgId}/settings`); + } const user = await verifySession(); @@ -40,6 +44,10 @@ export default async function BillingSettingsPage({ redirect(`/${orgId}`); } + if (!(org?.org?.isBillingOrg && orgUser?.isOwner)) { + redirect(`/${orgId}`); + } + const t = await getTranslations(); return ( diff --git a/src/app/[orgId]/settings/(private)/billing/page.tsx b/src/app/[orgId]/settings/(private)/billing/page.tsx index b108d461..bad8bda2 100644 --- a/src/app/[orgId]/settings/(private)/billing/page.tsx +++ b/src/app/[orgId]/settings/(private)/billing/page.tsx @@ -110,37 +110,42 @@ const planOptions: PlanOption[] = [ // Tier limits mapping derived from limit sets const tierLimits: Record< Tier | "basic", - { users: number; sites: number; domains: number; remoteNodes: number } + { users: number; sites: number; domains: number; remoteNodes: number; organizations: number } > = { basic: { users: freeLimitSet[FeatureId.USERS]?.value ?? 0, sites: freeLimitSet[FeatureId.SITES]?.value ?? 0, domains: freeLimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: freeLimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, tier1: { users: tier1LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0, domains: tier1LimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: tier1LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, tier2: { users: tier2LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier2LimitSet[FeatureId.SITES]?.value ?? 0, domains: tier2LimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: tier2LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, tier3: { users: tier3LimitSet[FeatureId.USERS]?.value ?? 0, sites: tier3LimitSet[FeatureId.SITES]?.value ?? 0, domains: tier3LimitSet[FeatureId.DOMAINS]?.value ?? 0, - remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0 + remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0, + organizations: tier3LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0 }, enterprise: { users: 0, // Custom for enterprise sites: 0, // Custom for enterprise domains: 0, // Custom for enterprise - remoteNodes: 0 // Custom for enterprise + remoteNodes: 0, // Custom for enterprise + organizations: 0 // Custom for enterprise } }; @@ -179,6 +184,7 @@ export default function BillingPage() { const SITES = "sites"; const DOMAINS = "domains"; const REMOTE_EXIT_NODES = "remoteExitNodes"; + const ORGINIZATIONS = "organizations"; // Confirmation dialog state const [showConfirmDialog, setShowConfirmDialog] = useState(false); @@ -619,6 +625,16 @@ export default function BillingPage() { }); } + // Check organizations + const organizationsUsage = getUsageValue(ORGINIZATIONS); + if (limits.organizations > 0 && organizationsUsage > limits.organizations) { + violations.push({ + feature: "Organizations", + currentUsage: organizationsUsage, + newLimit: limits.organizations + }); + } + return violations; }; @@ -752,7 +768,7 @@ export default function BillingPage() {
{t("billingMaximumLimits") || "Maximum Limits"}
- + {t("billingUsers") || "Users"} @@ -855,6 +871,41 @@ export default function BillingPage() { )} + + + {t("billingOrganizations") || + "Organizations"} + + + {isOverLimit(ORGINIZATIONS) ? ( + + + + + {getLimitValue(ORGINIZATIONS) ?? + t("billingUnlimited") ?? + "∞"}{" "} + {getLimitValue(ORGINIZATIONS) !== + null && "orgs"} + + + +

{t("billingUsageExceedsLimit", { current: getUsageValue(ORGINIZATIONS), limit: getLimitValue(ORGINIZATIONS) ?? 0 }) || `Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}

+
+
+ ) : ( + <> + {getLimitValue(ORGINIZATIONS) ?? + t("billingUnlimited") ?? + "∞"}{" "} + {getLimitValue(ORGINIZATIONS) !== + null && "orgs"} + + )} +
+
{t("billingRemoteNodes") || @@ -872,7 +923,7 @@ export default function BillingPage() { t("billingUnlimited") ?? "∞"}{" "} {getLimitValue(REMOTE_EXIT_NODES) !== - null && "remote nodes"} + null && "nodes"} @@ -885,7 +936,7 @@ export default function BillingPage() { t("billingUnlimited") ?? "∞"}{" "} {getLimitValue(REMOTE_EXIT_NODES) !== - null && "remote nodes"} + null && "nodes"} )} @@ -1016,6 +1067,17 @@ export default function BillingPage() { "Domains"} +
+ + + { + tierLimits[pendingTier.tier] + .organizations + }{" "} + {t("billingOrganizations") || + "Organizations"} + +
diff --git a/src/app/[orgId]/settings/(private)/license/layout.tsx b/src/app/[orgId]/settings/(private)/license/layout.tsx index 9083bb81..453b3372 100644 --- a/src/app/[orgId]/settings/(private)/license/layout.tsx +++ b/src/app/[orgId]/settings/(private)/license/layout.tsx @@ -4,6 +4,8 @@ import { redirect } from "next/navigation"; import { cache } from "react"; import { getTranslations } from "next-intl/server"; import { build } from "@server/build"; +import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser"; +import { getCachedOrg } from "@app/lib/api/getCachedOrg"; type LicensesSettingsProps = { children: React.ReactNode; @@ -27,6 +29,26 @@ export default async function LicensesSetingsLayoutProps({ redirect(`/`); } + let orgUser = null; + try { + const res = await getCachedOrgUser(orgId, user.userId); + orgUser = res.data.data; + } catch { + redirect(`/${orgId}`); + } + + let org = null; + try { + const res = await getCachedOrg(orgId); + org = res.data.data; + } catch { + redirect(`/${orgId}`); + } + + if (!org?.org?.isBillingOrg || !orgUser?.isOwner) { + redirect(`/${orgId}`); + } + const t = await getTranslations(); return ( diff --git a/src/app/[orgId]/settings/general/page.tsx b/src/app/[orgId]/settings/general/page.tsx index 0b3ae3d5..0a2ed39b 100644 --- a/src/app/[orgId]/settings/general/page.tsx +++ b/src/app/[orgId]/settings/general/page.tsx @@ -3,11 +3,7 @@ import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog"; import { Button } from "@app/components/ui/button"; import { useOrgContext } from "@app/hooks/useOrgContext"; import { toast } from "@app/hooks/useToast"; -import { - useState, - useTransition, - useActionState -} from "react"; +import { useState, useTransition, useActionState } from "react"; import { Form, FormControl, @@ -54,7 +50,7 @@ export default function GeneralPage() { return ( - {build !== "saas" && } + {!org.org.isBillingOrg && } ); } diff --git a/src/app/[orgId]/settings/layout.tsx b/src/app/[orgId]/settings/layout.tsx index 34ed3ac2..8ee7b1dc 100644 --- a/src/app/[orgId]/settings/layout.tsx +++ b/src/app/[orgId]/settings/layout.tsx @@ -77,12 +77,16 @@ export default async function SettingsLayout(props: SettingsLayoutProps) { } } catch (e) {} + const primaryOrg = orgs.find((o) => o.orgId === params.orgId)?.isPrimaryOrg; + return ( {children} diff --git a/src/app/auth/signup/page.tsx b/src/app/auth/signup/page.tsx index 9ab7b7e6..f51ac904 100644 --- a/src/app/auth/signup/page.tsx +++ b/src/app/auth/signup/page.tsx @@ -15,6 +15,7 @@ export default async function Page(props: { redirect: string | undefined; email: string | undefined; fromSmartLogin: string | undefined; + skipVerificationEmail: string | undefined; }>; }) { const searchParams = await props.searchParams; @@ -75,6 +76,10 @@ export default async function Page(props: { inviteId={inviteId} emailParam={searchParams.email} fromSmartLogin={searchParams.fromSmartLogin === "true"} + skipVerificationEmail={ + searchParams.skipVerificationEmail === "true" || + searchParams.skipVerificationEmail === "1" + } />

diff --git a/src/app/navigation.tsx b/src/app/navigation.tsx index 7df4364a..be3ad7d3 100644 --- a/src/app/navigation.tsx +++ b/src/app/navigation.tsx @@ -31,6 +31,10 @@ export type SidebarNavSection = { items: SidebarNavItem[]; }; +export type OrgNavSectionsOptions = { + isPrimaryOrg?: boolean; +}; + // Merged from 'user-management-and-resources' branch export const orgLangingNavItems: SidebarNavItem[] = [ { @@ -40,7 +44,10 @@ export const orgLangingNavItems: SidebarNavItem[] = [ } ]; -export const orgNavSections = (env?: Env): SidebarNavSection[] => [ +export const orgNavSections = ( + env?: Env, + options?: OrgNavSectionsOptions +): SidebarNavSection[] => [ { heading: "sidebarGeneral", items: [ @@ -214,28 +221,28 @@ export const orgNavSections = (env?: Env): SidebarNavSection[] => [ title: "sidebarSettings", href: "/{orgId}/settings/general", icon: - }, - - ...(build == "saas" - ? [ + } + ] + }, + ...(build == "saas" && options?.isPrimaryOrg + ? [ + { + heading: "sidebarBillingAndLicenses", + items: [ { title: "sidebarBilling", href: "/{orgId}/settings/billing", icon: - } - ] - : []), - ...(build == "saas" - ? [ + }, { title: "sidebarEnterpriseLicenses", href: "/{orgId}/settings/license", icon: } ] - : []) - ] - } + } + ] + : []) ]; export const adminNavSections = (env?: Env): SidebarNavSection[] => [ diff --git a/src/app/page.tsx b/src/app/page.tsx index df1a81df..f6f30276 100644 --- a/src/app/page.tsx +++ b/src/app/page.tsx @@ -73,7 +73,7 @@ export default async function Page(props: { if (!orgs.length) { if (!env.flags.disableUserCreateOrg || user.serverAdmin) { - redirect("/setup"); + redirect("/setup?firstOrg"); } } @@ -86,6 +86,14 @@ export default async function Page(props: { targetOrgId = lastOrgCookie; } else { let ownedOrg = orgs.find((org) => org.isOwner); + let primaryOrg = orgs.find((org) => org.isPrimaryOrg); + if (!ownedOrg) { + if (primaryOrg) { + ownedOrg = primaryOrg; + } else { + ownedOrg = orgs[0]; + } + } if (!ownedOrg) { ownedOrg = orgs[0]; } diff --git a/src/app/setup/page.tsx b/src/app/setup/page.tsx index c8b2af19..c7e6de6a 100644 --- a/src/app/setup/page.tsx +++ b/src/app/setup/page.tsx @@ -4,19 +4,14 @@ import { Button } from "@/components/ui/button"; import { Input } from "@/components/ui/input"; import { toast } from "@app/hooks/useToast"; import { useCallback, useEffect, useState } from "react"; -import { - Card, - CardContent, - CardDescription, - CardHeader, - CardTitle -} from "@app/components/ui/card"; import { formatAxiosError } from "@app/lib/api"; import { createApiClient } from "@app/lib/api"; import { useEnvContext } from "@app/hooks/useEnvContext"; +import { useUserContext } from "@app/hooks/useUserContext"; +import { build } from "@server/build"; import { Separator } from "@/components/ui/separator"; import { z } from "zod"; -import { useRouter } from "next/navigation"; +import { useRouter, useSearchParams } from "next/navigation"; import { useForm } from "react-hook-form"; import { zodResolver } from "@hookform/resolvers/zod"; import { @@ -35,7 +30,7 @@ import { CollapsibleContent, CollapsibleTrigger } from "@app/components/ui/collapsible"; -import { ChevronsUpDown } from "lucide-react"; +import { ArrowRight, ChevronsUpDown } from "lucide-react"; import { cn } from "@app/lib/cn"; type Step = "org" | "site" | "resources"; @@ -45,6 +40,7 @@ export default function StepperForm() { const [orgIdTaken, setOrgIdTaken] = useState(false); const t = useTranslations(); const { env } = useEnvContext(); + const { user } = useUserContext(); const [loading, setLoading] = useState(false); const [isChecked, setIsChecked] = useState(false); @@ -54,7 +50,10 @@ export default function StepperForm() { const orgSchema = z.object({ orgName: z.string().min(1, { message: t("orgNameRequired") }), - orgId: z.string().min(1, { message: t("orgIdRequired") }), + orgId: z + .string() + .min(1, { message: t("orgIdRequired") }) + .max(32, { message: t("orgIdMaxLength") }), subnet: z.string().min(1, { message: t("subnetRequired") }), utilitySubnet: z.string().min(1, { message: t("subnetRequired") }) }); @@ -71,12 +70,27 @@ export default function StepperForm() { const api = createApiClient(useEnvContext()); const router = useRouter(); + const searchParams = useSearchParams(); + const isFirstOrg = searchParams.get("firstOrg") != null; // Fetch default subnet on component mount useEffect(() => { fetchDefaultSubnet(); }, []); + // Prefill org name and id when build is saas and firstOrg query param is set + useEffect(() => { + if (build !== "saas" || !user || !isFirstOrg) return; + + const orgName = user.email + ? `${user.email}'s Organization` + : "My Organization"; + const orgId = `org_${user.userId}`; + orgForm.setValue("orgName", orgName); + orgForm.setValue("orgId", orgId); + debouncedCheckOrgIdAvailability(orgId); + }, []); + const fetchDefaultSubnet = async () => { try { const res = await api.get(`/pick-org-defaults`); @@ -129,6 +143,16 @@ export default function StepperForm() { .replace(/^-+|-+$/g, ""); }; + const sanitizeOrgId = (value: string) => { + return value + .toLowerCase() + .replace(/\s+/g, "-") + .replace(/[^a-z0-9_-]/g, "") + .replace(/-+/g, "-") + .replace(/^-+|-+$/g, "") + .slice(0, 32); + }; + async function orgSubmit(values: z.infer) { if (orgIdTaken) { return; @@ -161,263 +185,254 @@ export default function StepperForm() { } return ( - <> - - - {t("setupNewOrg")} - {t("setupCreate")} - - -

-
-
-
- 1 -
- - {t("setupCreateOrg")} - -
-
-
- 2 -
- - {t("siteCreate")} - -
-
-
- 3 -
- - {t("setupCreateResources")} - -
-
+
+
+

+ {t("setupNewOrg")} +

+

+ {t("setupCreate")} +

+
+
+
+
+ 1 +
+ + {t("setupCreateOrg")} + +
+
+
+ 2 +
+ + {t("siteCreate")} + +
+
+
+ 3 +
+ + {t("setupCreateResources")} + +
+
- + - {currentStep === "org" && ( -
- - ( - - - {t("setupOrgName")} - - - { - // Prevent "/" in orgName input - const sanitizedValue = - e.target.value.replace( - /\//g, - "-" - ); - const orgId = - generateId( - sanitizedValue - ); - orgForm.setValue( - "orgId", - orgId - ); - orgForm.setValue( - "orgName", - sanitizedValue - ); - debouncedCheckOrgIdAvailability( - orgId - ); - }} - value={field.value.replace( - /\//g, - "-" - )} - /> - - - - {t("orgDisplayName")} - - - )} - /> - ( - - - {t("orgId")} - - - - - - - {t( - "setupIdentifierMessage" - )} - - - )} - /> + {currentStep === "org" && ( + + + ( + + {t("setupOrgName")} + + { + // Prevent "/" in orgName input + const sanitizedValue = + e.target.value.replace( + /\//g, + "-" + ); + const orgId = + generateId(sanitizedValue); + orgForm.setValue( + "orgId", + orgId + ); + orgForm.setValue( + "orgName", + sanitizedValue + ); + debouncedCheckOrgIdAvailability( + orgId + ); + }} + value={field.value.replace( + /\//g, + "-" + )} + /> + + + + {t("orgDisplayName")} + + + )} + /> + ( + + {t("orgId")} + + { + const value = sanitizeOrgId( + e.target.value + ); + field.onChange(value); + setOrgIdTaken(false); + if (value) { + debouncedCheckOrgIdAvailability( + value + ); + } + }} + /> + + + + {t("setupIdentifierMessage")} + + + )} + /> - +
+ + - +

+ {t("advancedSettings")} +

+
+ + + {t("toggle")} +
- - ( - - - {t( - "setupSubnetAdvanced" - )} - - - - - - - {t( - "setupSubnetDescription" - )} - - + + +
+ + ( + + + {t("setupSubnetAdvanced")} + + + + + + + {t("setupSubnetDescription")} + + + )} + /> + + ( + + + {t("setupUtilitySubnet")} + + + + + + + {t( + "setupUtilitySubnetDescription" )} - /> + + + )} + /> + +
- ( - - - {t( - "setupUtilitySubnet" - )} - - - - - - - {t( - "setupUtilitySubnetDescription" - )} - - - )} - /> - - + {orgIdTaken && !orgCreated ? ( + + + {t("setupErrorIdentifier")} + + + ) : null} - {orgIdTaken && !orgCreated ? ( - - - {t("setupErrorIdentifier")} - - - ) : null} + {/* Error Alert removed, errors now shown as toast */} - {/* Error Alert removed, errors now shown as toast */} - -
- -
- - - )} -
- - - +
+ +
+ + + )} +
); } diff --git a/src/components/LayoutSidebar.tsx b/src/components/LayoutSidebar.tsx index 15951402..3095b1fd 100644 --- a/src/components/LayoutSidebar.tsx +++ b/src/components/LayoutSidebar.tsx @@ -189,10 +189,12 @@ export function LayoutSidebar({
- {canShowProductUpdates && ( + {canShowProductUpdates ? (
+ ) : ( +
)} {build === "enterprise" && ( diff --git a/src/components/MemberResourcesPortal.tsx b/src/components/MemberResourcesPortal.tsx index 4d3a7717..93456b12 100644 --- a/src/components/MemberResourcesPortal.tsx +++ b/src/components/MemberResourcesPortal.tsx @@ -58,6 +58,18 @@ type Resource = { siteName?: string | null; }; +type SiteResource = { + siteResourceId: number; + name: string; + destination: string; + mode: string; + protocol: string | null; + enabled: boolean; + alias: string | null; + aliasAddress: string | null; + type: 'site'; +}; + type MemberResourcesPortalProps = { orgId: string; }; @@ -334,7 +346,9 @@ export default function MemberResourcesPortal({ const { toast } = useToast(); const [resources, setResources] = useState([]); + const [siteResources, setSiteResources] = useState([]); const [filteredResources, setFilteredResources] = useState([]); + const [filteredSiteResources, setFilteredSiteResources] = useState([]); const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const [searchQuery, setSearchQuery] = useState(""); @@ -360,7 +374,9 @@ export default function MemberResourcesPortal({ if (response.data.success) { setResources(response.data.data.resources); + setSiteResources(response.data.data.siteResources || []); setFilteredResources(response.data.data.resources); + setFilteredSiteResources(response.data.data.siteResources || []); } else { setError("Failed to load resources"); } @@ -417,17 +433,61 @@ export default function MemberResourcesPortal({ setFilteredResources(filtered); + // Filter and sort site resources + const filteredSites = siteResources.filter( + (resource) => + resource.name + .toLowerCase() + .includes(searchQuery.toLowerCase()) || + resource.destination + .toLowerCase() + .includes(searchQuery.toLowerCase()) + ); + + // Sort site resources + filteredSites.sort((a, b) => { + switch (sortBy) { + case "name-asc": + return a.name.localeCompare(b.name); + case "name-desc": + return b.name.localeCompare(a.name); + case "domain-asc": + case "domain-desc": + // Sort by destination for site resources + const destCompare = sortBy === "domain-asc" + ? a.destination.localeCompare(b.destination) + : b.destination.localeCompare(a.destination); + return destCompare; + case "status-enabled": + return b.enabled ? 1 : -1; + case "status-disabled": + return a.enabled ? 1 : -1; + default: + return a.name.localeCompare(b.name); + } + }); + + setFilteredSiteResources(filteredSites); + // Reset to first page when search/sort changes setCurrentPage(1); - }, [resources, searchQuery, sortBy]); + }, [resources, siteResources, searchQuery, sortBy]); // Calculate pagination - const totalPages = Math.ceil(filteredResources.length / itemsPerPage); + const totalItems = filteredResources.length + filteredSiteResources.length; + const totalPages = Math.ceil(totalItems / itemsPerPage); const startIndex = (currentPage - 1) * itemsPerPage; const paginatedResources = filteredResources.slice( startIndex, startIndex + itemsPerPage ); + const remainingSlots = itemsPerPage - paginatedResources.length; + const paginatedSiteResources = remainingSlots > 0 + ? filteredSiteResources.slice( + Math.max(0, startIndex - filteredResources.length), + Math.max(0, startIndex - filteredResources.length) + remainingSlots + ) + : []; const handleOpenResource = (resource: Resource) => { // Open the resource in a new tab @@ -575,7 +635,7 @@ export default function MemberResourcesPortal({
{/* Resources Content */} - {filteredResources.length === 0 ? ( + {filteredResources.length === 0 && filteredSiteResources.length === 0 ? ( /* Enhanced Empty State */ @@ -623,9 +683,20 @@ export default function MemberResourcesPortal({ ) : ( <> - {/* Resources Grid */} -
- {paginatedResources.map((resource) => ( + {/* Public Resources Section */} + {paginatedResources.length > 0 && ( + <> +
+

+ + Public Resources +

+

+ Web applications and services accessible via browser +

+
+
+ {paginatedResources.map((resource) => (
@@ -702,13 +773,167 @@ export default function MemberResourcesPortal({ ))}
+ + )} + + {/* Private Resources (Site Resources) Section */} + {paginatedSiteResources.length > 0 && ( + <> +
+

+ + Private Resources +

+

+ Internal network resources accessible via client +

+
+
+ {paginatedSiteResources.map((siteResource) => ( + +
+
+
+ + + + + {siteResource.name} + + + +

+ {siteResource.name} +

+
+
+
+
+ +
+ +
+
Resource Details
+
+ Mode: + + {siteResource.mode} + +
+ {siteResource.protocol && ( +
+ Protocol: + + {siteResource.protocol} + +
+ )} + {siteResource.alias && ( +
+ Alias: + + {siteResource.alias} + +
+ )} + {siteResource.aliasAddress && ( +
+ Alias Address: + + {siteResource.aliasAddress} + +
+ )} +
+ Status: + + {siteResource.enabled ? 'Enabled' : 'Disabled'} + +
+
+
+
+
+ +
+ {siteResource.alias ? ( + <> + {/* Alias as primary */} +
+
+ {siteResource.alias} +
+ +
+ {/* Destination as secondary */} +
+ {siteResource.destination} +
+ + ) : ( + /* Destination as primary when no alias */ +
+
+ {siteResource.destination} +
+ +
+ )} +
+
+ +
+
+ + Requires Client Connection +
+
+
+ ))} +
+ + )} {/* Pagination Controls */} diff --git a/src/components/OrgSelector.tsx b/src/components/OrgSelector.tsx index e139e43a..cacaf553 100644 --- a/src/components/OrgSelector.tsx +++ b/src/components/OrgSelector.tsx @@ -20,12 +20,13 @@ import { TooltipProvider, TooltipTrigger } from "@app/components/ui/tooltip"; +import { Badge } from "@app/components/ui/badge"; import { useEnvContext } from "@app/hooks/useEnvContext"; import { cn } from "@app/lib/cn"; import { ListUserOrgsResponse } from "@server/routers/org"; import { Check, ChevronsUpDown, Plus, Building2, Users } from "lucide-react"; -import { useRouter } from "next/navigation"; -import { useState } from "react"; +import { usePathname, useRouter } from "next/navigation"; +import { useMemo, useState } from "react"; import { useUserContext } from "@app/hooks/useUserContext"; import { useTranslations } from "next-intl"; @@ -43,11 +44,23 @@ export function OrgSelector({ const { user } = useUserContext(); const [open, setOpen] = useState(false); const router = useRouter(); + const pathname = usePathname(); const { env } = useEnvContext(); const t = useTranslations(); const selectedOrg = orgs?.find((org) => org.orgId === orgId); + const sortedOrgs = useMemo(() => { + if (!orgs?.length) return orgs ?? []; + return [...orgs].sort((a, b) => { + const aPrimary = Boolean(a.isPrimaryOrg); + const bPrimary = Boolean(b.isPrimaryOrg); + if (aPrimary && !bPrimary) return -1; + if (!aPrimary && bPrimary) return 1; + return 0; + }); + }, [orgs]); + const orgSelectorContent = ( @@ -124,25 +137,39 @@ export function OrgSelector({ )} - {orgs?.map((org) => ( + {sortedOrgs.map((org) => ( { setOpen(false); - router.push(`/${org.orgId}/settings`); + const newPath = pathname.replace( + /^\/[^/]+/, + `/${org.orgId}` + ); + router.push(newPath); }} className="mx-2 rounded-md" >
-
- +
+ {org.name} - - {t("organization")} - +
+ + {org.orgId} + + {org.isPrimaryOrg && ( + + {t("primary")} + + )} +
{ console.error(e);