Merge branch 'cloud-multi-org' into dev

This commit is contained in:
Owen
2026-02-17 21:01:44 -08:00
43 changed files with 1670 additions and 1101 deletions

View File

@@ -1031,6 +1031,7 @@
"pangolinSetup": "Setup - Pangolin",
"orgNameRequired": "Organization name is required",
"orgIdRequired": "Organization ID is required",
"orgIdMaxLength": "Organization ID must be at most 32 characters",
"orgErrorCreate": "An error occurred while creating org",
"pageNotFound": "Page Not Found",
"pageNotFoundDescription": "Oops! The page you're looking for doesn't exist.",
@@ -1266,6 +1267,7 @@
"sidebarLogAndAnalytics": "Log & Analytics",
"sidebarBluePrints": "Blueprints",
"sidebarOrganization": "Organization",
"sidebarBillingAndLicenses": "Billing & Licenses",
"sidebarLogsAnalytics": "Analytics",
"blueprints": "Blueprints",
"blueprintsDescription": "Apply declarative configurations and view previous runs",
@@ -1427,6 +1429,7 @@
"billingSites": "Sites",
"billingUsers": "Users",
"billingDomains": "Domains",
"billingOrganizations": "Orgs",
"billingRemoteExitNodes": "Remote Nodes",
"billingNoLimitConfigured": "No limit configured",
"billingEstimatedPeriod": "Estimated Billing Period",
@@ -1469,6 +1472,7 @@
"failed": "Failed",
"createNewOrgDescription": "Create a new organization",
"organization": "Organization",
"primary": "Primary",
"port": "Port",
"securityKeyManage": "Manage Security Keys",
"securityKeyDescription": "Add or remove security keys for passwordless authentication",

View File

@@ -55,7 +55,9 @@ export const orgs = pgTable("orgs", {
.notNull()
.default(0),
sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format)
sshCaPublicKey: text("sshCaPublicKey") // SSH CA public key (OpenSSH format)
sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format)
isBillingOrg: boolean("isBillingOrg"),
billingOrgId: varchar("billingOrgId")
});
export const orgDomains = pgTable("orgDomains", {

View File

@@ -47,7 +47,9 @@ export const orgs = sqliteTable("orgs", {
.notNull()
.default(0),
sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format)
sshCaPublicKey: text("sshCaPublicKey") // SSH CA public key (OpenSSH format)
sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format)
isBillingOrg: integer("isBillingOrg", { mode: "boolean" }),
billingOrgId: text("billingOrgId")
});
export const userDomains = sqliteTable("userDomains", {

View File

@@ -4,6 +4,7 @@ export enum FeatureId {
EGRESS_DATA_MB = "egressDataMb",
DOMAINS = "domains",
REMOTE_EXIT_NODES = "remoteExitNodes",
ORGINIZATIONS = "organizations",
TIER1 = "tier1"
}
@@ -19,6 +20,8 @@ export async function getFeatureDisplayName(featureId: FeatureId): Promise<strin
return "Domains";
case FeatureId.REMOTE_EXIT_NODES:
return "Remote Exit Nodes";
case FeatureId.ORGINIZATIONS:
return "Organizations";
case FeatureId.TIER1:
return "Home Lab";
default:

View File

@@ -7,18 +7,12 @@ export type LimitSet = Partial<{
};
}>;
export const sandboxLimitSet: LimitSet = {
[FeatureId.USERS]: { value: 1, description: "Sandbox limit" },
[FeatureId.SITES]: { value: 1, description: "Sandbox limit" },
[FeatureId.DOMAINS]: { value: 0, description: "Sandbox limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 0, description: "Sandbox limit" },
};
export const freeLimitSet: LimitSet = {
[FeatureId.SITES]: { value: 5, description: "Basic limit" },
[FeatureId.USERS]: { value: 5, description: "Basic limit" },
[FeatureId.DOMAINS]: { value: 5, description: "Basic limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Basic limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Basic limit" },
};
export const tier1LimitSet: LimitSet = {
@@ -26,6 +20,7 @@ export const tier1LimitSet: LimitSet = {
[FeatureId.SITES]: { value: 10, description: "Home limit" },
[FeatureId.DOMAINS]: { value: 10, description: "Home limit" },
[FeatureId.REMOTE_EXIT_NODES]: { value: 1, description: "Home limit" },
[FeatureId.ORGINIZATIONS]: { value: 1, description: "Home limit" },
};
export const tier2LimitSet: LimitSet = {
@@ -45,6 +40,10 @@ export const tier2LimitSet: LimitSet = {
value: 3,
description: "Team limit"
},
[FeatureId.ORGINIZATIONS]: {
value: 1,
description: "Team limit"
}
};
export const tier3LimitSet: LimitSet = {
@@ -64,4 +63,8 @@ export const tier3LimitSet: LimitSet = {
value: 20,
description: "Business limit"
},
[FeatureId.ORGINIZATIONS]: {
value: 5,
description: "Business limit"
},
};

View File

@@ -1,34 +1,19 @@
import { eq, sql, and } from "drizzle-orm";
import { v4 as uuidv4 } from "uuid";
import { PutObjectCommand } from "@aws-sdk/client-s3";
import {
db,
usage,
customers,
sites,
newts,
limits,
Usage,
Limit,
Transaction
Transaction,
orgs
} from "@server/db";
import { FeatureId, getFeatureMeterId } from "./features";
import logger from "@server/logger";
import { sendToClient } from "#dynamic/routers/ws";
import { build } from "@server/build";
import { s3Client } from "@server/lib/s3";
import cache from "@server/lib/cache";
interface StripeEvent {
identifier?: string;
timestamp: number;
event_name: string;
payload: {
value: number;
stripe_customer_id: string;
};
}
export function noop() {
if (build !== "saas") {
return true;
@@ -37,41 +22,11 @@ export function noop() {
}
export class UsageService {
private bucketName: string | undefined;
private events: StripeEvent[] = [];
private lastUploadTime: number = Date.now();
private isUploading: boolean = false;
constructor() {
if (noop()) {
return;
}
// this.bucketName = process.env.S3_BUCKET || undefined;
// // Periodically check and upload events
// setInterval(() => {
// this.checkAndUploadEvents().catch((err) => {
// logger.error("Error in periodic event upload:", err);
// });
// }, 30000); // every 30 seconds
// // Handle graceful shutdown on SIGTERM
// process.on("SIGTERM", async () => {
// logger.info(
// "SIGTERM received, uploading events before shutdown..."
// );
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// });
// // Handle SIGINT as well (Ctrl+C)
// process.on("SIGINT", async () => {
// logger.info("SIGINT received, uploading events before shutdown...");
// await this.forceUpload();
// logger.info("Events uploaded, proceeding with shutdown");
// process.exit(0);
// });
}
/**
@@ -91,6 +46,8 @@ export class UsageService {
return null;
}
let orgIdToUse = await this.getBillingOrg(orgId, transaction);
// Truncate value to 11 decimal places
value = this.truncateValue(value);
@@ -100,20 +57,10 @@ export class UsageService {
while (attempt <= maxRetries) {
try {
// Get subscription data for this org (with caching)
const customerId = await this.getCustomerId(orgId, featureId);
if (!customerId) {
logger.warn(
`No subscription data found for org ${orgId} and feature ${featureId}`
);
return null;
}
let usage;
if (transaction) {
usage = await this.internalAddUsage(
orgId,
orgIdToUse,
featureId,
value,
transaction
@@ -121,7 +68,7 @@ export class UsageService {
} else {
await db.transaction(async (trx) => {
usage = await this.internalAddUsage(
orgId,
orgIdToUse,
featureId,
value,
trx
@@ -129,11 +76,6 @@ export class UsageService {
});
}
// Log event for Stripe
// if (privateConfig.getRawPrivateConfig().flags.usage_reporting) {
// await this.logStripeEvent(featureId, value, customerId);
// }
return usage || null;
} catch (error: any) {
// Check if this is a deadlock error
@@ -150,7 +92,7 @@ export class UsageService {
const delay = baseDelay + jitter;
logger.warn(
`Deadlock detected for ${orgId}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms`
`Deadlock detected for ${orgIdToUse}/${featureId}, retrying attempt ${attempt}/${maxRetries} after ${delay.toFixed(0)}ms`
);
await new Promise((resolve) => setTimeout(resolve, delay));
@@ -158,7 +100,7 @@ export class UsageService {
}
logger.error(
`Failed to add usage for ${orgId}/${featureId} after ${attempt} attempts:`,
`Failed to add usage for ${orgIdToUse}/${featureId} after ${attempt} attempts:`,
error
);
break;
@@ -169,7 +111,7 @@ export class UsageService {
}
private async internalAddUsage(
orgId: string,
orgId: string, // here the orgId is the billing org already resolved by getBillingOrg in updateCount
featureId: FeatureId,
value: number,
trx: Transaction
@@ -188,17 +130,22 @@ export class UsageService {
featureId,
orgId,
meterId,
latestValue: value,
instantaneousValue: value || 0,
latestValue: value || 0,
updatedAt: Math.floor(Date.now() / 1000)
})
.onConflictDoUpdate({
target: usage.usageId,
set: {
latestValue: sql`${usage.latestValue} + ${value}`
instantaneousValue: sql`COALESCE(${usage.instantaneousValue}, 0) + ${value}`
}
})
.returning();
logger.debug(
`Added usage for org ${orgId} feature ${featureId}: +${value}, new instantaneousValue: ${returnUsage.instantaneousValue}`
);
return returnUsage;
}
@@ -221,18 +168,10 @@ export class UsageService {
if (noop()) {
return;
}
try {
if (!customerId) {
customerId =
(await this.getCustomerId(orgId, featureId)) || undefined;
if (!customerId) {
logger.warn(
`No subscription data found for org ${orgId} and feature ${featureId}`
);
return;
}
}
let orgIdToUse = await this.getBillingOrg(orgId);
try {
// Truncate value to 11 decimal places if provided
if (value !== undefined && value !== null) {
value = this.truncateValue(value);
@@ -242,7 +181,7 @@ export class UsageService {
await db.transaction(async (trx) => {
// Get existing meter record
const usageId = `${orgId}-${featureId}`;
const usageId = `${orgIdToUse}-${featureId}`;
// Get current usage record
[currentUsage] = await trx
.select()
@@ -264,7 +203,7 @@ export class UsageService {
await trx.insert(usage).values({
usageId,
featureId,
orgId,
orgId: orgIdToUse,
meterId,
instantaneousValue: value || 0,
latestValue: value || 0,
@@ -278,7 +217,7 @@ export class UsageService {
// }
} catch (error) {
logger.error(
`Failed to update count usage for ${orgId}/${featureId}:`,
`Failed to update count usage for ${orgIdToUse}/${featureId}:`,
error
);
}
@@ -288,7 +227,9 @@ export class UsageService {
orgId: string,
featureId: FeatureId
): Promise<string | null> {
const cacheKey = `customer_${orgId}_${featureId}`;
let orgIdToUse = await this.getBillingOrg(orgId);
const cacheKey = `customer_${orgIdToUse}_${featureId}`;
const cached = cache.get<string>(cacheKey);
if (cached) {
@@ -302,7 +243,7 @@ export class UsageService {
customerId: customers.customerId
})
.from(customers)
.where(eq(customers.orgId, orgId))
.where(eq(customers.orgId, orgIdToUse))
.limit(1);
if (!customer) {
@@ -317,112 +258,13 @@ export class UsageService {
return customerId;
} catch (error) {
logger.error(
`Failed to get subscription data for ${orgId}/${featureId}:`,
`Failed to get subscription data for ${orgIdToUse}/${featureId}:`,
error
);
return null;
}
}
private async logStripeEvent(
featureId: FeatureId,
value: number,
customerId: string
): Promise<void> {
// Truncate value to 11 decimal places before sending to Stripe
const truncatedValue = this.truncateValue(value);
const event: StripeEvent = {
identifier: uuidv4(),
timestamp: Math.floor(new Date().getTime() / 1000),
event_name: featureId,
payload: {
value: truncatedValue,
stripe_customer_id: customerId
}
};
this.addEventToMemory(event);
await this.checkAndUploadEvents();
}
private addEventToMemory(event: StripeEvent): void {
if (!this.bucketName) {
logger.warn(
"S3 bucket name is not configured, skipping event storage."
);
return;
}
this.events.push(event);
}
private async checkAndUploadEvents(): Promise<void> {
const now = Date.now();
const timeSinceLastUpload = now - this.lastUploadTime;
// Check if at least 1 minute has passed since last upload
if (timeSinceLastUpload >= 60000 && this.events.length > 0) {
await this.uploadEventsToS3();
}
}
private async uploadEventsToS3(): Promise<void> {
if (!this.bucketName) {
logger.warn(
"S3 bucket name is not configured, skipping S3 upload."
);
return;
}
if (this.events.length === 0) {
return;
}
// Check if already uploading
if (this.isUploading) {
logger.debug("Already uploading events, skipping");
return;
}
this.isUploading = true;
try {
// Take a snapshot of current events and clear the array
const eventsToUpload = [...this.events];
this.events = [];
this.lastUploadTime = Date.now();
const fileName = this.generateEventFileName();
const fileContent = JSON.stringify(eventsToUpload, null, 2);
// Upload to S3
const uploadCommand = new PutObjectCommand({
Bucket: this.bucketName,
Key: fileName,
Body: fileContent,
ContentType: "application/json"
});
await s3Client.send(uploadCommand);
logger.info(
`Uploaded ${fileName} to S3 with ${eventsToUpload.length} events`
);
} catch (error) {
logger.error("Failed to upload events to S3:", error);
// Note: Events are lost if upload fails. In a production system,
// you might want to add the events back to the array or implement retry logic
} finally {
this.isUploading = false;
}
}
private generateEventFileName(): string {
const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
const uuid = uuidv4().substring(0, 8);
return `events-${timestamp}-${uuid}.json`;
}
public async getUsage(
orgId: string,
featureId: FeatureId,
@@ -432,7 +274,9 @@ export class UsageService {
return null;
}
const usageId = `${orgId}-${featureId}`;
let orgIdToUse = await this.getBillingOrg(orgId, trx);
const usageId = `${orgIdToUse}-${featureId}`;
try {
const [result] = await trx
@@ -444,7 +288,7 @@ export class UsageService {
if (!result) {
// Lets create one if it doesn't exist using upsert to handle race conditions
logger.info(
`Creating new usage record for ${orgId}/${featureId}`
`Creating new usage record for ${orgIdToUse}/${featureId}`
);
const meterId = getFeatureMeterId(featureId);
@@ -454,7 +298,7 @@ export class UsageService {
.values({
usageId,
featureId,
orgId,
orgId: orgIdToUse,
meterId,
latestValue: 0,
updatedAt: Math.floor(Date.now() / 1000)
@@ -476,7 +320,7 @@ export class UsageService {
} catch (insertError) {
// Fallback: try to fetch existing record in case of any insert issues
logger.warn(
`Insert failed for ${orgId}/${featureId}, attempting to fetch existing record:`,
`Insert failed for ${orgIdToUse}/${featureId}, attempting to fetch existing record:`,
insertError
);
const [existingUsage] = await trx
@@ -491,19 +335,41 @@ export class UsageService {
return result;
} catch (error) {
logger.error(
`Failed to get usage for ${orgId}/${featureId}:`,
`Failed to get usage for ${orgIdToUse}/${featureId}:`,
error
);
throw error;
}
}
public async forceUpload(): Promise<void> {
if (this.events.length > 0) {
// Force upload regardless of time
this.lastUploadTime = 0; // Reset to force upload
await this.uploadEventsToS3();
public async getBillingOrg(
orgId: string,
trx: Transaction | typeof db = db
): Promise<string> {
let orgIdToUse = orgId;
// get the org
const [org] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
throw new Error(`Organization with ID ${orgId} not found`);
}
if (!org.isBillingOrg) {
if (org.billingOrgId) {
orgIdToUse = org.billingOrgId;
} else {
throw new Error(
`Organization ${orgId} is not a billing org and does not have a billingOrgId set`
);
}
}
return orgIdToUse;
}
public async checkLimitSet(
@@ -515,6 +381,9 @@ export class UsageService {
if (noop()) {
return false;
}
let orgIdToUse = await this.getBillingOrg(orgId, trx);
// This method should check the current usage against the limits set for the organization
// and kick out all of the sites on the org
let hasExceededLimits = false;
@@ -528,7 +397,7 @@ export class UsageService {
.from(limits)
.where(
and(
eq(limits.orgId, orgId),
eq(limits.orgId, orgIdToUse),
eq(limits.featureId, featureId)
)
);
@@ -537,11 +406,11 @@ export class UsageService {
orgLimits = await trx
.select()
.from(limits)
.where(eq(limits.orgId, orgId));
.where(eq(limits.orgId, orgIdToUse));
}
if (orgLimits.length === 0) {
logger.debug(`No limits set for org ${orgId}`);
logger.debug(`No limits set for org ${orgIdToUse}`);
return false;
}
@@ -552,7 +421,7 @@ export class UsageService {
currentUsage = usage;
} else {
currentUsage = await this.getUsage(
orgId,
orgIdToUse,
limit.featureId as FeatureId,
trx
);
@@ -563,10 +432,10 @@ export class UsageService {
currentUsage?.latestValue ||
0;
logger.debug(
`Current usage for org ${orgId} on feature ${limit.featureId}: ${usageValue}`
`Current usage for org ${orgIdToUse} on feature ${limit.featureId}: ${usageValue}`
);
logger.debug(
`Limit for org ${orgId} on feature ${limit.featureId}: ${limit.value}`
`Limit for org ${orgIdToUse} on feature ${limit.featureId}: ${limit.value}`
);
if (
currentUsage &&
@@ -574,7 +443,7 @@ export class UsageService {
usageValue > limit.value
) {
logger.debug(
`Org ${orgId} has exceeded limit for ${limit.featureId}: ` +
`Org ${orgIdToUse} has exceeded limit for ${limit.featureId}: ` +
`${usageValue} > ${limit.value}`
);
hasExceededLimits = true;
@@ -582,7 +451,7 @@ export class UsageService {
}
}
} catch (error) {
logger.error(`Error checking limits for org ${orgId}:`, error);
logger.error(`Error checking limits for org ${orgIdToUse}:`, error);
}
return hasExceededLimits;

View File

@@ -1,206 +0,0 @@
import { isValidCIDR } from "@server/lib/validators";
import { getNextAvailableOrgSubnet } from "@server/lib/ip";
import {
actions,
apiKeyOrg,
apiKeys,
db,
domains,
Org,
orgDomains,
orgs,
roleActions,
roles,
userOrgs
} from "@server/db";
import { eq } from "drizzle-orm";
import { defaultRoleAllowedActions } from "@server/routers/role";
import { FeatureId, limitsService, sandboxLimitSet } from "@server/lib/billing";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import config from "@server/lib/config";
import { generateCA } from "@server/private/lib/sshCA";
import { encrypt } from "@server/lib/crypto";
export async function createUserAccountOrg(
userId: string,
userEmail: string
): Promise<{
success: boolean;
org?: {
orgId: string;
name: string;
subnet: string;
};
error?: string;
}> {
// const subnet = await getNextAvailableOrgSubnet();
const orgId = "org_" + userId;
const name = `${userEmail}'s Organization`;
// if (!isValidCIDR(subnet)) {
// return {
// success: false,
// error: "Invalid subnet format. Please provide a valid CIDR notation."
// };
// }
// // make sure the subnet is unique
// const subnetExists = await db
// .select()
// .from(orgs)
// .where(eq(orgs.subnet, subnet))
// .limit(1);
// if (subnetExists.length > 0) {
// return { success: false, error: `Subnet ${subnet} already exists` };
// }
// make sure the orgId is unique
const orgExists = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (orgExists.length > 0) {
return {
success: false,
error: `Organization with ID ${orgId} already exists`
};
}
let error = "";
let org: Org | null = null;
await db.transaction(async (trx) => {
const allDomains = await trx
.select()
.from(domains)
.where(eq(domains.configManaged, true));
const utilitySubnet = config.getRawConfig().orgs.utility_subnet_group;
// Generate SSH CA keys for the org
// const ca = generateCA(`${orgId}-ca`);
// const encryptionKey = config.getRawConfig().server.secret!;
// const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey);
const newOrg = await trx
.insert(orgs)
.values({
orgId,
name,
// subnet
subnet: "100.90.128.0/24", // TODO: this should not be hardcoded - or can it be the same in all orgs?
utilitySubnet: utilitySubnet,
createdAt: new Date().toISOString(),
// sshCaPrivateKey: encryptedCaPrivateKey,
// sshCaPublicKey: ca.publicKeyOpenSSH
})
.returning();
if (newOrg.length === 0) {
error = "Failed to create organization";
trx.rollback();
return;
}
org = newOrg[0];
// Create admin role within the same transaction
const [insertedRole] = await trx
.insert(roles)
.values({
orgId: newOrg[0].orgId,
isAdmin: true,
name: "Admin",
description: "Admin role with the most permissions"
})
.returning({ roleId: roles.roleId });
if (!insertedRole || !insertedRole.roleId) {
error = "Failed to create Admin role";
trx.rollback();
return;
}
const roleId = insertedRole.roleId;
// Get all actions and create role actions
const actionIds = await trx.select().from(actions).execute();
if (actionIds.length > 0) {
await trx.insert(roleActions).values(
actionIds.map((action) => ({
roleId,
actionId: action.actionId,
orgId: newOrg[0].orgId
}))
);
}
if (allDomains.length) {
await trx.insert(orgDomains).values(
allDomains.map((domain) => ({
orgId: newOrg[0].orgId,
domainId: domain.domainId
}))
);
}
await trx.insert(userOrgs).values({
userId,
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
const memberRole = await trx
.insert(roles)
.values({
name: "Member",
description: "Members can only view resources",
orgId
})
.returning();
await trx.insert(roleActions).values(
defaultRoleAllowedActions.map((action) => ({
roleId: memberRole[0].roleId,
actionId: action,
orgId
}))
);
});
await limitsService.applyLimitSetToOrg(orgId, sandboxLimitSet);
if (!org) {
return { success: false, error: "Failed to create org" };
}
if (error) {
return {
success: false,
error: `Failed to create org: ${error}`
};
}
// make sure we have the stripe customer
const customerId = await createCustomer(orgId, userEmail);
if (customerId) {
await usageService.updateCount(orgId, FeatureId.USERS, 1, customerId); // Only 1 because we are crating the org
}
return {
org: {
orgId,
name,
// subnet
subnet: "100.90.128.0/24"
},
success: true
};
}

View File

@@ -4,14 +4,18 @@ import {
clientSitesAssociationsCache,
db,
domains,
exitNodeOrgs,
exitNodes,
olms,
orgDomains,
orgs,
remoteExitNodes,
resources,
sites
sites,
userOrgs
} from "@server/db";
import { newts, newtSessions } from "@server/db";
import { eq, and, inArray, sql } from "drizzle-orm";
import { eq, and, inArray, sql, count, countDistinct } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
@@ -19,6 +23,8 @@ import { sendToClient } from "#dynamic/routers/ws";
import { deletePeer } from "@server/routers/gerbil/peers";
import { OlmErrorCodes } from "@server/routers/olm/error";
import { sendTerminateClient } from "@server/routers/client/terminate";
import { usageService } from "./billing/usageService";
import { FeatureId } from "./billing";
export type DeleteOrgByIdResult = {
deletedNewtIds: string[];
@@ -60,6 +66,11 @@ export async function deleteOrgById(
const deletedNewtIds: string[] = [];
const olmsToTerminate: string[] = [];
let domainCount: number | null = null;
let siteCount: number | null = null;
let userCount: number | null = null;
let remoteExitNodeCount: number | null = null;
await db.transaction(async (trx) => {
for (const site of orgSites) {
if (site.pubKey) {
@@ -74,9 +85,7 @@ export async function deleteOrgById(
deletedNewtIds.push(deletedNewt.newtId);
await trx
.delete(newtSessions)
.where(
eq(newtSessions.newtId, deletedNewt.newtId)
);
.where(eq(newtSessions.newtId, deletedNewt.newtId));
}
}
}
@@ -137,9 +146,74 @@ export async function deleteOrgById(
.where(inArray(domains.domainId, domainIdsToDelete));
}
await trx.delete(resources).where(eq(resources.orgId, orgId));
await usageService.add(orgId, FeatureId.ORGINIZATIONS, -1, trx); // here we are decreasing the org count BEFORE deleting the org because we need to still be able to get the org to get the billing org inside of here
await trx.delete(orgs).where(eq(orgs.orgId, orgId));
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (billingOrgs.length > 0) {
const billingOrgIds = billingOrgs.map((org) => org.orgId);
const [domainCountRes] = await trx
.select({ count: count() })
.from(orgDomains)
.where(inArray(orgDomains.orgId, billingOrgIds));
domainCount = domainCountRes.count;
const [siteCountRes] = await trx
.select({ count: count() })
.from(sites)
.where(inArray(sites.orgId, billingOrgIds));
siteCount = siteCountRes.count;
const [userCountRes] = await trx
.select({ count: countDistinct(userOrgs.userId) })
.from(userOrgs)
.where(inArray(userOrgs.orgId, billingOrgIds));
userCount = userCountRes.count;
const [remoteExitNodeCountRes] = await trx
.select({ count: countDistinct(exitNodeOrgs.exitNodeId) })
.from(exitNodeOrgs)
.where(inArray(exitNodeOrgs.orgId, billingOrgIds));
remoteExitNodeCount = remoteExitNodeCountRes.count;
}
}
});
if (org.billingOrgId) {
usageService.updateCount(
org.billingOrgId,
FeatureId.DOMAINS,
domainCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.SITES,
siteCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.USERS,
userCount ?? 0
);
usageService.updateCount(
org.billingOrgId,
FeatureId.REMOTE_EXIT_NODES,
remoteExitNodeCount ?? 0
);
}
return { deletedNewtIds, olmsToTerminate };
}
@@ -155,15 +229,13 @@ export function sendTerminationMessages(result: DeleteOrgByIdResult): void {
);
}
for (const olmId of result.olmsToTerminate) {
sendTerminateClient(
0,
OlmErrorCodes.TERMINATED_REKEYED,
olmId
).catch((error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
});
sendTerminateClient(0, OlmErrorCodes.TERMINATED_REKEYED, olmId).catch(
(error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
}
);
}
}

142
server/lib/userOrg.ts Normal file
View File

@@ -0,0 +1,142 @@
import {
db,
Org,
orgs,
resources,
siteResources,
sites,
Transaction,
UserOrg,
userOrgs,
userResources,
userSiteResources,
userSites
} from "@server/db";
import { eq, and, inArray, ne, exists } from "drizzle-orm";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
export async function assignUserToOrg(
org: Org,
values: typeof userOrgs.$inferInsert,
trx: Transaction | typeof db = db
) {
const [userOrg] = await trx.insert(userOrgs).values(values).returning();
// calculate if the user is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, org.orgId)
)
);
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userOrg.userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(org.orgId, FeatureId.USERS, 1, trx);
}
}
}
export async function removeUserFromOrg(
org: Org,
userId: string,
trx: Transaction | typeof db = db
) {
await trx
.delete(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, org.orgId)));
await trx.delete(userResources).where(
and(
eq(userResources.userId, userId),
exists(
trx
.select()
.from(resources)
.where(
and(
eq(resources.resourceId, userResources.resourceId),
eq(resources.orgId, org.orgId)
)
)
)
)
);
await trx.delete(userSiteResources).where(
and(
eq(userSiteResources.userId, userId),
exists(
trx
.select()
.from(siteResources)
.where(
and(
eq(
siteResources.siteResourceId,
userSiteResources.siteResourceId
),
eq(siteResources.orgId, org.orgId)
)
)
)
)
);
await trx.delete(userSites).where(
and(
eq(userSites.userId, userId),
exists(
db
.select()
.from(sites)
.where(
and(
eq(sites.siteId, userSites.siteId),
eq(sites.orgId, org.orgId)
)
)
)
)
);
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const billingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
const billingOrgIds = billingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheUserIsStillIn = await trx
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, userId),
inArray(userOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheUserIsStillIn.length === 0) {
await usageService.add(org.orgId, FeatureId.USERS, -1, trx);
}
}
}

View File

@@ -12,7 +12,8 @@
*/
import { build } from "@server/build";
import { db, customers, subscriptions } from "@server/db";
import { db, customers, subscriptions, orgs } from "@server/db";
import logger from "@server/logger";
import { Tier } from "@server/types/Tiers";
import { eq, and, ne } from "drizzle-orm";
@@ -27,37 +28,60 @@ export async function getOrgTierData(
}
try {
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return { tier, active };
}
let orgIdToUse = org.orgId;
if (!org.isBillingOrg) {
if (!org.billingOrgId) {
logger.warn(
`Org ${orgId} is not a billing org and does not have a billingOrgId`
);
return { tier, active };
}
orgIdToUse = org.billingOrgId;
}
// Get customer for org
const [customer] = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.where(eq(customers.orgId, orgIdToUse))
.limit(1);
if (customer) {
// Query for active subscriptions that are not license type
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
ne(subscriptions.type, "license")
)
)
.limit(1);
if (!customer) {
return { tier, active };
}
if (subscription) {
// Validate that subscription.type is one of the expected tier values
if (
subscription.type === "tier1" ||
subscription.type === "tier2" ||
subscription.type === "tier3"
) {
tier = subscription.type;
active = true;
}
// Query for active subscriptions that are not license type
const [subscription] = await db
.select()
.from(subscriptions)
.where(
and(
eq(subscriptions.customerId, customer.customerId),
eq(subscriptions.status, "active"),
ne(subscriptions.type, "license")
)
)
.limit(1);
if (subscription) {
// Validate that subscription.type is one of the expected tier values
if (
subscription.type === "tier1" ||
subscription.type === "tier2" ||
subscription.type === "tier3"
) {
tier = subscription.type;
active = true;
}
}
} catch (error) {

View File

@@ -15,7 +15,18 @@ import { SubscriptionType } from "./hooks/getSubType";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { Tier } from "@server/types/Tiers";
import logger from "@server/logger";
import { db, idp, idpOrg, loginPage, loginPageBranding, loginPageBrandingOrg, loginPageOrg, orgs, resources, roles } from "@server/db";
import {
db,
idp,
idpOrg,
loginPage,
loginPageBranding,
loginPageBrandingOrg,
loginPageOrg,
orgs,
resources,
roles
} from "@server/db";
import { eq } from "drizzle-orm";
/**
@@ -59,10 +70,7 @@ async function capRetentionDays(
}
// Get current org settings
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId));
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
logger.warn(`Org ${orgId} not found when capping retention days`);
@@ -110,18 +118,13 @@ async function capRetentionDays(
// Apply updates if needed
if (needsUpdate) {
await db
.update(orgs)
.set(updates)
.where(eq(orgs.orgId, orgId));
await db.update(orgs).set(updates).where(eq(orgs.orgId, orgId));
logger.info(
`Successfully capped retention days for org ${orgId} to max ${maxRetentionDays} days`
);
} else {
logger.debug(
`No retention day capping needed for org ${orgId}`
);
logger.debug(`No retention day capping needed for org ${orgId}`);
}
}
@@ -134,6 +137,35 @@ export async function handleTierChange(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// Get all orgs that have this orgId as their billingOrgId
const associatedOrgs = await db
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, orgId));
logger.info(
`Found ${associatedOrgs.length} org(s) associated with billing org ${orgId}`
);
// Loop over all associated orgs and apply tier changes
for (const org of associatedOrgs) {
await handleTierChangeForOrg(org.orgId, newTier, previousTier);
}
logger.info(
`Completed tier change handling for all orgs associated with billing org ${orgId}`
);
}
async function handleTierChangeForOrg(
orgId: string,
newTier: SubscriptionType | null,
previousTier?: SubscriptionType | null
): Promise<void> {
logger.info(
`Handling tier change for org ${orgId}: ${previousTier || "none"} -> ${newTier || "free"}`
);
// License subscriptions are handled separately and don't use the tier matrix
if (newTier === "license") {
logger.debug(
@@ -314,9 +346,7 @@ async function disableLoginPageDomain(orgId: string): Promise<void> {
);
if (existingLoginPage) {
await db
.delete(loginPageOrg)
.where(eq(loginPageOrg.orgId, orgId));
await db.delete(loginPageOrg).where(eq(loginPageOrg.orgId, orgId));
await db
.delete(loginPage)

View File

@@ -112,11 +112,13 @@ export async function getOrgSubscriptionsData(
throw new Error(`Not found`);
}
const billingOrgId = org[0].billingOrgId || org[0].orgId;
// Get customer for org
const customer = await db
.select()
.from(customers)
.where(eq(customers.orgId, orgId))
.where(eq(customers.orgId, billingOrgId))
.limit(1);
const subscriptionsWithItems: Array<{

View File

@@ -85,10 +85,14 @@ export async function getOrgUsage(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
const organizations = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
FeatureId.ORGINIZATIONS
);
// const egressData = await usageService.getUsage(
// orgId,
// FeatureId.EGRESS_DATA_MB
// );
if (sites) {
usageData.push(sites);
@@ -96,15 +100,18 @@ export async function getOrgUsage(
if (users) {
usageData.push(users);
}
if (egressData) {
usageData.push(egressData);
}
// if (egressData) {
// usageData.push(egressData);
// }
if (domains) {
usageData.push(domains);
}
if (remoteExitNodes) {
usageData.push(remoteExitNodes);
}
if (organizations) {
usageData.push(organizations);
}
const orgLimits = await db
.select()

View File

@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
orgs
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -25,7 +32,7 @@ import { createRemoteExitNodeSession } from "#private/auth/sessions/remoteExitNo
import { fromError } from "zod-validation-error";
import { hashPassword, verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { and, eq, inArray, ne } from "drizzle-orm";
import { getNextAvailableSubnet } from "@server/lib/exitNodes";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
@@ -169,7 +176,17 @@ export async function createRemoteExitNode(
);
}
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => {
if (!existingExitNode) {
@@ -217,19 +234,43 @@ export async function createRemoteExitNode(
});
}
numExitNodeOrgs = await trx
.select()
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId));
});
// calculate if the node is in any other of the orgs before we count it as an add to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(
and(
eq(orgs.billingOrgId, org.billingOrgId),
ne(orgs.orgId, orgId)
)
);
if (numExitNodeOrgs) {
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length
);
}
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select()
.from(exitNodeOrgs)
.where(
and(
eq(
exitNodeOrgs.exitNodeId,
existingExitNode.exitNodeId
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.add(
orgId,
FeatureId.REMOTE_EXIT_NODES,
1,
trx
);
}
}
});
const token = generateSessionToken();
await createRemoteExitNodeSession(token, remoteExitNodeId);

View File

@@ -13,9 +13,9 @@
import { NextFunction, Request, Response } from "express";
import { z } from "zod";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes } from "@server/db";
import { db, ExitNodeOrg, exitNodeOrgs, exitNodes, orgs } from "@server/db";
import { remoteExitNodes } from "@server/db";
import { and, count, eq } from "drizzle-orm";
import { and, count, eq, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -50,7 +50,8 @@ export async function deleteRemoteExitNode(
const [remoteExitNode] = await db
.select()
.from(remoteExitNodes)
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId))
.limit(1);
if (!remoteExitNode) {
return next(
@@ -70,7 +71,17 @@ export async function deleteRemoteExitNode(
);
}
let numExitNodeOrgs: ExitNodeOrg[] | undefined;
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Org with ID ${orgId} not found`
)
);
}
await db.transaction(async (trx) => {
await trx
.delete(exitNodeOrgs)
@@ -81,38 +92,39 @@ export async function deleteRemoteExitNode(
)
);
const [remainingExitNodeOrgs] = await trx
.select({ count: count() })
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.exitNodeId, remoteExitNode.exitNodeId!));
// calculate if the user is in any other of the orgs before we count it as an remove to the billing org
if (org.billingOrgId) {
const otherBillingOrgs = await trx
.select()
.from(orgs)
.where(eq(orgs.billingOrgId, org.billingOrgId));
if (remainingExitNodeOrgs.count === 0) {
await trx
.delete(remoteExitNodes)
const billingOrgIds = otherBillingOrgs.map((o) => o.orgId);
const orgsInBillingDomainThatTheNodeIsStillIn = await trx
.select()
.from(exitNodeOrgs)
.where(
eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId)
and(
eq(
exitNodeOrgs.exitNodeId,
remoteExitNode.exitNodeId!
),
inArray(exitNodeOrgs.orgId, billingOrgIds)
)
);
await trx
.delete(exitNodes)
.where(
eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId!)
if (orgsInBillingDomainThatTheNodeIsStillIn.length === 0) {
await usageService.add(
orgId,
FeatureId.REMOTE_EXIT_NODES,
-1,
trx
);
}
}
numExitNodeOrgs = await trx
.select()
.from(exitNodeOrgs)
.where(eq(exitNodeOrgs.orgId, orgId));
});
if (numExitNodeOrgs) {
await usageService.updateCount(
orgId,
FeatureId.REMOTE_EXIT_NODES,
numExitNodeOrgs.length
);
}
return response(res, {
data: null,
success: true,

View File

@@ -15,11 +15,10 @@ import {
import { verifyPassword } from "@server/auth/password";
import { verifyTotpCode } from "@server/auth/totp";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import {
deleteOrgById,
sendTerminationMessages
} from "@server/lib/deleteOrg";
import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg";
import { UserType } from "@server/types/UserTypes";
import { build } from "@server/build";
import { getOrgTierData } from "#dynamic/lib/billing";
const deleteMyAccountBody = z.strictObject({
password: z.string().optional(),
@@ -40,11 +39,6 @@ export type DeleteMyAccountSuccessResponse = {
success: true;
};
/**
* Self-service account deletion (saas only). Returns preview when no password;
* requires password and optional 2FA code to perform deletion. Uses shared
* deleteOrgById for each owned org (delete-my-account may delete multiple orgs).
*/
export async function deleteMyAccount(
req: Request,
res: Response,
@@ -91,18 +85,35 @@ export async function deleteMyAccount(
const ownedOrgsRows = await db
.select({
orgId: userOrgs.orgId
orgId: userOrgs.orgId,
isOwner: userOrgs.isOwner,
isBillingOrg: orgs.isBillingOrg
})
.from(userOrgs)
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.isOwner, true)
)
and(eq(userOrgs.userId, userId), eq(userOrgs.isOwner, true))
);
const orgIds = ownedOrgsRows.map((r) => r.orgId);
if (build === "saas" && orgIds.length > 0) {
const primaryOrgId = ownedOrgsRows.find(
(r) => r.isBillingOrg && r.isOwner
)?.orgId;
if (primaryOrgId) {
const { tier, active } = await getOrgTierData(primaryOrgId);
if (active && tier) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"You must cancel your subscription before deleting your account"
)
);
}
}
}
if (!password) {
const orgsWithNames =
orgIds.length > 0
@@ -219,10 +230,7 @@ export async function deleteMyAccount(
} catch (error) {
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An error occurred"
)
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -1,7 +1,7 @@
import { NextFunction, Request, Response } from "express";
import { db, users } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { email, z } from "zod";
import { fromError } from "zod-validation-error";
import createHttpError from "http-errors";
import response from "@server/lib/response";
@@ -21,7 +21,6 @@ import { hashPassword } from "@server/auth/password";
import { checkValidInvite } from "@server/auth/checkValidInvite";
import { passwordSchema } from "@server/auth/passwordSchema";
import { UserType } from "@server/types/UserTypes";
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
import { build } from "@server/build";
import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend";
@@ -31,7 +30,8 @@ export const signupBodySchema = z.object({
inviteToken: z.string().optional(),
inviteId: z.string().optional(),
termsAcceptedTimestamp: z.string().nullable().optional(),
marketingEmailConsent: z.boolean().optional()
marketingEmailConsent: z.boolean().optional(),
skipVerificationEmail: z.boolean().optional()
});
export type SignUpBody = z.infer<typeof signupBodySchema>;
@@ -62,7 +62,8 @@ export async function signup(
inviteToken,
inviteId,
termsAcceptedTimestamp,
marketingEmailConsent
marketingEmailConsent,
skipVerificationEmail
} = parsedBody.data;
const passwordHash = await hashPassword(password);
@@ -198,26 +199,6 @@ export async function signup(
// orgId: null,
// });
if (build == "saas") {
const { success, error, org } = await createUserAccountOrg(
userId,
email
);
if (!success) {
if (error) {
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)
);
}
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to create user account and organization"
)
);
}
}
const token = generateSessionToken();
const sess = await createSession(token, userId);
const isSecure = req.protocol === "https";
@@ -235,7 +216,13 @@ export async function signup(
}
if (config.getRawConfig().flags?.require_email_verification) {
sendEmailVerificationCode(email, userId);
if (!skipVerificationEmail) {
sendEmailVerificationCode(email, userId);
} else {
logger.debug(
`User ${email} opted out of verification email during signup.`
);
}
return response<SignUpResponse>(res, {
data: {
@@ -243,7 +230,9 @@ export async function signup(
},
success: true,
error: false,
message: `User created successfully. We sent an email to ${email} with a verification code.`,
message: skipVerificationEmail
? "User created successfully. Please verify your email."
: `User created successfully. We sent an email to ${email} with a verification code.`,
status: HttpCode.OK
});
}

View File

@@ -148,7 +148,6 @@ export async function createOrgDomain(
}
}
let numOrgDomains: OrgDomains[] | undefined;
let aRecords: CreateDomainResponse["aRecords"];
let cnameRecords: CreateDomainResponse["cnameRecords"];
let txtRecords: CreateDomainResponse["txtRecords"];
@@ -347,20 +346,9 @@ export async function createOrgDomain(
await trx.insert(dnsRecords).values(recordsToInsert);
}
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
await usageService.add(orgId, FeatureId.DOMAINS, 1, trx);
});
if (numOrgDomains) {
await usageService.updateCount(
orgId,
FeatureId.DOMAINS,
numOrgDomains.length
);
}
if (!returned) {
return next(
createHttpError(

View File

@@ -36,8 +36,6 @@ export async function deleteAccountDomain(
}
const { domainId, orgId } = parsed.data;
let numOrgDomains: OrgDomains[] | undefined;
await db.transaction(async (trx) => {
const [existing] = await trx
.select()
@@ -79,20 +77,9 @@ export async function deleteAccountDomain(
await trx.delete(domains).where(eq(domains.domainId, domainId));
numOrgDomains = await trx
.select()
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId));
await usageService.add(orgId, FeatureId.DOMAINS, -1, trx);
});
if (numOrgDomains) {
await usageService.updateCount(
orgId,
FeatureId.DOMAINS,
numOrgDomains.length
);
}
return response<DeleteAccountDomainResponse>(res, {
data: { success: true },
success: true,

View File

@@ -65,9 +65,8 @@ authenticated.use(verifySessionUserMiddleware);
authenticated.get("/pick-org-defaults", org.pickOrgDefaults);
authenticated.get("/org/checkId", org.checkId);
if (build === "oss" || build === "enterprise") {
authenticated.put("/org", getUserOrgs, org.createOrg);
}
authenticated.put("/org", getUserOrgs, org.createOrg);
authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs);
authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs);
@@ -87,16 +86,14 @@ authenticated.post(
org.updateOrg
);
if (build !== "saas") {
authenticated.delete(
"/org/:orgId",
verifyOrgAccess,
verifyUserIsOrgOwner,
verifyUserHasAction(ActionsEnum.deleteOrg),
logActionAudit(ActionsEnum.deleteOrg),
org.deleteOrg
);
}
authenticated.delete(
"/org/:orgId",
verifyOrgAccess,
verifyUserIsOrgOwner,
verifyUserHasAction(ActionsEnum.deleteOrg),
logActionAudit(ActionsEnum.deleteOrg),
org.deleteOrg
);
authenticated.put(
"/org/:orgId/site",

View File

@@ -36,6 +36,10 @@ import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import {
assignUserToOrg,
removeUserFromOrg
} from "@server/lib/userOrg";
const ensureTrailingSlash = (url: string): string => {
return url;
@@ -436,6 +440,7 @@ export async function validateOidcCallback(
}
}
// These are the orgs that the user should be provisioned into based on the IdP mappings and the token claims
logger.debug("User org info", { userOrgInfo });
let existingUserId = existingUser?.userId;
@@ -454,15 +459,32 @@ export async function validateOidcCallback(
);
if (!existingUserOrgs.length) {
// delete all auto -provisioned user orgs
await db
.delete(userOrgs)
// delete all auto-provisioned user orgs
const autoProvisionedUserOrgs = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, existingUser.userId),
eq(userOrgs.autoProvisioned, true)
)
);
const orgIdsToRemove = autoProvisionedUserOrgs.map(
(uo) => uo.orgId
);
if (orgIdsToRemove.length > 0) {
const orgsToRemove = await db
.select()
.from(orgs)
.where(inArray(orgs.orgId, orgIdsToRemove));
for (const org of orgsToRemove) {
await removeUserFromOrg(
org,
existingUser.userId,
db
);
}
}
await calculateUserClientsForOrgs(existingUser.userId);
@@ -484,7 +506,7 @@ export async function validateOidcCallback(
}
}
const orgUserCounts: { orgId: string; userCount: number }[] = [];
const orgUserCounts: { orgId: string; userCount: number }[] = [];
// sync the user with the orgs and roles
await db.transaction(async (trx) => {
@@ -538,15 +560,14 @@ export async function validateOidcCallback(
);
if (orgsToDelete.length > 0) {
await trx.delete(userOrgs).where(
and(
eq(userOrgs.userId, userId!),
inArray(
userOrgs.orgId,
orgsToDelete.map((org) => org.orgId)
)
)
);
const orgIdsToRemove = orgsToDelete.map((org) => org.orgId);
const fullOrgsToRemove = await trx
.select()
.from(orgs)
.where(inArray(orgs.orgId, orgIdsToRemove));
for (const org of fullOrgsToRemove) {
await removeUserFromOrg(org, userId!, trx);
}
}
// Update roles for existing auto-provisioned orgs where the role has changed
@@ -587,15 +608,24 @@ export async function validateOidcCallback(
);
if (orgsToAdd.length > 0) {
await trx.insert(userOrgs).values(
orgsToAdd.map((org) => ({
userId: userId!,
orgId: org.orgId,
roleId: org.roleId,
autoProvisioned: true,
dateCreated: new Date().toISOString()
}))
);
for (const org of orgsToAdd) {
const [fullOrg] = await trx
.select()
.from(orgs)
.where(eq(orgs.orgId, org.orgId));
if (fullOrg) {
await assignUserToOrg(
fullOrg,
{
orgId: org.orgId,
userId: userId!,
roleId: org.roleId,
autoProvisioned: true,
},
trx
);
}
}
}
// Loop through all the orgs and get the total number of users from the userOrgs table

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { eq } from "drizzle-orm";
import { and, count, eq } from "drizzle-orm";
import {
domains,
Org,
@@ -24,15 +24,24 @@ import { OpenAPITags, registry } from "@server/openApi";
import { isValidCIDR } from "@server/lib/validators";
import { createCustomer } from "#dynamic/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { FeatureId, limitsService, freeLimitSet } from "@server/lib/billing";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { doCidrsOverlap } from "@server/lib/ip";
import { generateCA } from "@server/private/lib/sshCA";
import { encrypt } from "@server/lib/crypto";
const validOrgIdRegex = /^[a-z0-9_]+(-[a-z0-9_]+)*$/;
const createOrgSchema = z.strictObject({
orgId: z.string(),
orgId: z
.string()
.min(1, "Organization ID is required")
.max(32, "Organization ID must be at most 32 characters")
.refine((val) => validOrgIdRegex.test(val), {
message:
"Organization ID must contain only lowercase letters, numbers, underscores, and single hyphens (no leading, trailing, or consecutive hyphens)"
}),
name: z.string().min(1).max(255),
subnet: z
// .union([z.cidrv4(), z.cidrv6()])
@@ -110,6 +119,7 @@ export async function createOrg(
// )
// );
// }
//
// make sure the orgId is unique
const orgExists = await db
@@ -136,8 +146,71 @@ export async function createOrg(
);
}
let isFirstOrg: boolean | null = null;
let billingOrgIdForNewOrg: string | null = null;
if (build === "saas" && req.user) {
const ownedOrgs = await db
.select()
.from(userOrgs)
.where(
and(
eq(userOrgs.userId, req.user.userId),
eq(userOrgs.isOwner, true)
)
);
if (ownedOrgs.length === 0) {
isFirstOrg = true;
} else {
isFirstOrg = false;
const [billingOrg] = await db
.select({ orgId: orgs.orgId })
.from(orgs)
.innerJoin(userOrgs, eq(orgs.orgId, userOrgs.orgId))
.where(
and(
eq(userOrgs.userId, req.user.userId),
eq(userOrgs.isOwner, true),
eq(orgs.isBillingOrg, true)
)
)
.limit(1);
if (billingOrg) {
billingOrgIdForNewOrg = billingOrg.orgId;
}
}
}
if (build == "saas" && billingOrgIdForNewOrg) {
const usage = await usageService.getUsage(billingOrgIdForNewOrg, FeatureId.ORGINIZATIONS);
if (!usage) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No usage data found for this organization"
)
);
}
const rejectOrgs = await usageService.checkLimitSet(
billingOrgIdForNewOrg,
FeatureId.ORGINIZATIONS,
{
...usage,
instantaneousValue: (usage.instantaneousValue || 0) + 1
} // We need to add one to know if we are violating the limit
);
if (rejectOrgs) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Organization limit exceeded. Please upgrade your plan."
)
);
}
}
let error = "";
let org: Org | null = null;
let numOrgs: number | null = null;
await db.transaction(async (trx) => {
const allDomains = await trx
@@ -145,11 +218,21 @@ export async function createOrg(
.from(domains)
.where(eq(domains.configManaged, true));
// // Generate SSH CA keys for the org
// Generate SSH CA keys for the org
// const ca = generateCA(`${orgId}-ca`);
// const encryptionKey = config.getRawConfig().server.secret!;
// const encryptedCaPrivateKey = encrypt(ca.privateKeyPem, encryptionKey);
const saasBillingFields =
build === "saas" && req.user && isFirstOrg !== null
? isFirstOrg
? { isBillingOrg: true as const, billingOrgId: orgId } // if this is the first org, it becomes the billing org for itself
: {
isBillingOrg: false as const,
billingOrgId: billingOrgIdForNewOrg
}
: {};
const newOrg = await trx
.insert(orgs)
.values({
@@ -159,7 +242,8 @@ export async function createOrg(
utilitySubnet,
createdAt: new Date().toISOString(),
// sshCaPrivateKey: encryptedCaPrivateKey,
// sshCaPublicKey: ca.publicKeyOpenSSH
// sshCaPublicKey: ca.publicKeyOpenSSH,
...saasBillingFields
})
.returning();
@@ -261,6 +345,17 @@ export async function createOrg(
);
await calculateUserClientsForOrgs(ownerUserId, trx);
if (billingOrgIdForNewOrg) {
const [numOrgsResult] = await trx
.select({ count: count() })
.from(orgs)
.where(eq(orgs.billingOrgId, billingOrgIdForNewOrg)); // all the billable orgs including the primary org that is the billing org itself
numOrgs = numOrgsResult.count;
} else {
numOrgs = 1; // we only have one org if there is no billing org found out
}
});
if (!org) {
@@ -276,8 +371,8 @@ export async function createOrg(
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error));
}
if (build == "saas") {
// make sure we have the stripe customer
if (build === "saas" && isFirstOrg === true) {
await limitsService.applyLimitSetToOrg(orgId, freeLimitSet);
const customerId = await createCustomer(orgId, req.user?.email);
if (customerId) {
await usageService.updateCount(
@@ -289,6 +384,14 @@ export async function createOrg(
}
}
if (numOrgs) {
usageService.updateCount(
billingOrgIdForNewOrg || orgId,
FeatureId.ORGINIZATIONS,
numOrgs
);
}
return response(res, {
data: org,
success: true,

View File

@@ -7,6 +7,8 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { deleteOrgById, sendTerminationMessages } from "@server/lib/deleteOrg";
import { db, userOrgs, orgs } from "@server/db";
import { eq, and } from "drizzle-orm";
const deleteOrgSchema = z.strictObject({
orgId: z.string()
@@ -41,6 +43,48 @@ export async function deleteOrg(
);
}
const { orgId } = parsedParams.data;
const [data] = await db
.select()
.from(userOrgs)
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, req.user!.userId)
)
);
const org = data?.orgs;
const userOrg = data?.userOrgs;
if (!org || !userOrg) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Organization with ID ${orgId} not found`
)
);
}
if (!userOrg.isOwner) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Only organization owners can delete the organization"
)
);
}
if (org.isBillingOrg) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Cannot delete a primary organization"
)
);
}
const result = await deleteOrgById(orgId);
sendTerminationMessages(result);
return response(res, {

View File

@@ -40,7 +40,11 @@ const listOrgsSchema = z.object({
// responses: {}
// });
type ResponseOrg = Org & { isOwner?: boolean; isAdmin?: boolean };
type ResponseOrg = Org & {
isOwner?: boolean;
isAdmin?: boolean;
isPrimaryOrg?: boolean;
};
export type ListUserOrgsResponse = {
orgs: ResponseOrg[];
@@ -132,6 +136,9 @@ export async function listUserOrgs(
if (val.roles && val.roles.isAdmin) {
res.isAdmin = val.roles.isAdmin;
}
if (val.userOrgs?.isOwner && val.orgs?.isBillingOrg) {
res.isPrimaryOrg = val.orgs.isBillingOrg;
}
return res;
});

View File

@@ -8,7 +8,10 @@ import {
userOrgs,
resourcePassword,
resourcePincode,
resourceWhitelist
resourceWhitelist,
siteResources,
userSiteResources,
roleSiteResources
} from "@server/db";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
@@ -57,9 +60,21 @@ export async function getUserResources(
.from(roleResources)
.where(eq(roleResources.roleId, userRoleId));
const [directResources, roleResourceResults] = await Promise.all([
const directSiteResourcesQuery = db
.select({ siteResourceId: userSiteResources.siteResourceId })
.from(userSiteResources)
.where(eq(userSiteResources.userId, userId));
const roleSiteResourcesQuery = db
.select({ siteResourceId: roleSiteResources.siteResourceId })
.from(roleSiteResources)
.where(eq(roleSiteResources.roleId, userRoleId));
const [directResources, roleResourceResults, directSiteResourceResults, roleSiteResourceResults] = await Promise.all([
directResourcesQuery,
roleResourcesQuery
roleResourcesQuery,
directSiteResourcesQuery,
roleSiteResourcesQuery
]);
// Combine all accessible resource IDs
@@ -68,18 +83,25 @@ export async function getUserResources(
...roleResourceResults.map((r) => r.resourceId)
];
if (accessibleResourceIds.length === 0) {
return response(res, {
data: { resources: [] },
success: true,
error: false,
message: "No resources found",
status: HttpCode.OK
});
}
// Combine all accessible site resource IDs
const accessibleSiteResourceIds = [
...directSiteResourceResults.map((r) => r.siteResourceId),
...roleSiteResourceResults.map((r) => r.siteResourceId)
];
// Get resource details for accessible resources
const resourcesData = await db
let resourcesData: Array<{
resourceId: number;
name: string;
fullDomain: string | null;
ssl: boolean;
enabled: boolean;
sso: boolean;
protocol: string;
emailWhitelistEnabled: boolean;
}> = [];
if (accessibleResourceIds.length > 0) {
resourcesData = await db
.select({
resourceId: resources.resourceId,
name: resources.name,
@@ -98,6 +120,40 @@ export async function getUserResources(
eq(resources.enabled, true)
)
);
}
// Get site resource details for accessible site resources
let siteResourcesData: Array<{
siteResourceId: number;
name: string;
destination: string;
mode: string;
protocol: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
}> = [];
if (accessibleSiteResourceIds.length > 0) {
siteResourcesData = await db
.select({
siteResourceId: siteResources.siteResourceId,
name: siteResources.name,
destination: siteResources.destination,
mode: siteResources.mode,
protocol: siteResources.protocol,
enabled: siteResources.enabled,
alias: siteResources.alias,
aliasAddress: siteResources.aliasAddress
})
.from(siteResources)
.where(
and(
inArray(siteResources.siteResourceId, accessibleSiteResourceIds),
eq(siteResources.orgId, orgId),
eq(siteResources.enabled, true)
)
);
}
// Check for password, pincode, and whitelist protection for each resource
const resourcesWithAuth = await Promise.all(
@@ -161,8 +217,26 @@ export async function getUserResources(
})
);
// Format site resources
const siteResourcesFormatted = siteResourcesData.map((siteResource) => {
return {
siteResourceId: siteResource.siteResourceId,
name: siteResource.name,
destination: siteResource.destination,
mode: siteResource.mode,
protocol: siteResource.protocol,
enabled: siteResource.enabled,
alias: siteResource.alias,
aliasAddress: siteResource.aliasAddress,
type: 'site' as const
};
});
return response(res, {
data: { resources: resourcesWithAuth },
data: {
resources: resourcesWithAuth,
siteResources: siteResourcesFormatted
},
success: true,
error: false,
message: "User resources retrieved successfully",
@@ -190,5 +264,16 @@ export type GetUserResourcesResponse = {
protected: boolean;
protocol: string;
}>;
siteResources: Array<{
siteResourceId: number;
name: string;
destination: string;
mode: string;
protocol: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
type: 'site';
}>;
};
};

View File

@@ -6,7 +6,7 @@ import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and } from "drizzle-orm";
import { eq, and, count } from "drizzle-orm";
import { getUniqueSiteName } from "../../db/names";
import { addPeer } from "../gerbil/peers";
import { fromError } from "zod-validation-error";
@@ -288,7 +288,6 @@ export async function createSite(
const niceId = await getUniqueSiteName(orgId);
let newSite: Site | undefined;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => {
if (type == "newt") {
[newSite] = await trx
@@ -443,20 +442,9 @@ export async function createSite(
});
}
numSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, orgId));
await usageService.add(orgId, FeatureId.SITES, 1, trx);
});
if (numSites) {
await usageService.updateCount(
orgId,
FeatureId.SITES,
numSites.length
);
}
if (!newSite) {
return next(
createHttpError(

View File

@@ -64,7 +64,6 @@ export async function deleteSite(
}
let deletedNewtId: string | null = null;
let numSites: Site[] | undefined;
await db.transaction(async (trx) => {
if (site.type == "wireguard") {
@@ -103,19 +102,9 @@ export async function deleteSite(
await trx.delete(sites).where(eq(sites.siteId, siteId));
numSites = await trx
.select()
.from(sites)
.where(eq(sites.orgId, site.orgId));
await usageService.add(site.orgId, FeatureId.SITES, -1, trx);
});
if (numSites) {
await usageService.updateCount(
site.orgId,
FeatureId.SITES,
numSites.length
);
}
// Send termination message outside of transaction to prevent blocking
if (deletedNewtId) {
const payload = {

View File

@@ -1,8 +1,8 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, UserOrg } from "@server/db";
import { db, orgs, UserOrg } from "@server/db";
import { roles, userInvites, userOrgs, users } from "@server/db";
import { eq } from "drizzle-orm";
import { eq, and, inArray, ne } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -14,6 +14,7 @@ import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { build } from "@server/build";
import { assignUserToOrg } from "@server/lib/userOrg";
const acceptInviteBodySchema = z.strictObject({
token: z.string(),
@@ -125,8 +126,22 @@ export async function acceptInvite(
}
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, existingInvite.orgId))
.limit(1);
if (!org) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Organization does not exist. Please contact an admin."
)
);
}
let roleId: number;
let totalUsers: UserOrg[] | undefined;
// get the role to make sure it exists
const existingRole = await db
.select()
@@ -146,12 +161,15 @@ export async function acceptInvite(
}
await db.transaction(async (trx) => {
// add the user to the org
await trx.insert(userOrgs).values({
userId: existingUser[0].userId,
orgId: existingInvite.orgId,
roleId: existingInvite.roleId
});
await assignUserToOrg(
org,
{
userId: existingUser[0].userId,
orgId: existingInvite.orgId,
roleId: existingInvite.roleId
},
trx
);
// delete the invite
await trx
@@ -160,25 +178,11 @@ export async function acceptInvite(
await calculateUserClientsForOrgs(existingUser[0].userId, trx);
// Get the total number of users in the org now
totalUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, existingInvite.orgId));
logger.debug(
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}. Total users in org: ${totalUsers.length}`
`User ${existingUser[0].userId} accepted invite to org ${existingInvite.orgId}`
);
});
if (totalUsers) {
await usageService.updateCount(
existingInvite.orgId,
FeatureId.USERS,
totalUsers.length
);
}
return response<AcceptInviteResponse>(res, {
data: { accepted: true, orgId: existingInvite.orgId },
success: true,

View File

@@ -6,8 +6,8 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { db, UserOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import { db, orgs, UserOrg } from "@server/db";
import { and, eq, inArray, ne } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app";
import { usageService } from "@server/lib/billing/usageService";
@@ -16,6 +16,7 @@ import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { assignUserToOrg } from "@server/lib/userOrg";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
@@ -151,6 +152,21 @@ export async function createOrgUser(
);
}
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Organization not found"
)
);
}
const [idpRes] = await db
.select()
.from(idp)
@@ -172,8 +188,6 @@ export async function createOrgUser(
);
}
let orgUsers: UserOrg[] | undefined;
await db.transaction(async (trx) => {
const [existingUser] = await trx
.select()
@@ -207,15 +221,12 @@ export async function createOrgUser(
);
}
await trx
.insert(userOrgs)
.values({
orgId,
userId: existingUser.userId,
roleId: role.roleId,
autoProvisioned: false
})
.returning();
await assignUserToOrg(org, {
orgId,
userId: existingUser.userId,
roleId: role.roleId,
autoProvisioned: false
}, trx);
} else {
userId = generateId(15);
@@ -233,33 +244,16 @@ export async function createOrgUser(
})
.returning();
await trx
.insert(userOrgs)
.values({
await assignUserToOrg(org, {
orgId,
userId: newUser.userId,
roleId: role.roleId,
autoProvisioned: false
})
.returning();
}, trx);
}
// List all of the users in the org
orgUsers = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
await calculateUserClientsForOrgs(userId, trx);
});
if (orgUsers) {
await usageService.updateCount(
orgId,
FeatureId.USERS,
orgUsers.length
);
}
} else {
return next(
createHttpError(HttpCode.BAD_REQUEST, "User type is required")

View File

@@ -1,8 +1,16 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, resources, sites, UserOrg } from "@server/db";
import {
db,
orgs,
resources,
siteResources,
sites,
UserOrg,
userSiteResources
} from "@server/db";
import { userOrgs, userResources, users, userSites } from "@server/db";
import { and, count, eq, exists } from "drizzle-orm";
import { and, count, eq, exists, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -14,6 +22,7 @@ import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build";
import { UserType } from "@server/types/UserTypes";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { removeUserFromOrg } from "@server/lib/userOrg";
const removeUserSchema = z.strictObject({
userId: z.string(),
@@ -50,16 +59,16 @@ export async function removeUserOrg(
const { userId, orgId } = parsedParams.data;
// get the user first
const user = await db
const [user] = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)));
if (!user || user.length === 0) {
if (!user) {
return next(createHttpError(HttpCode.NOT_FOUND, "User not found"));
}
if (user[0].isOwner) {
if (user.isOwner) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
@@ -68,56 +77,20 @@ export async function removeUserOrg(
);
}
let userCount: UserOrg[] | undefined;
const [org] = await db
.select()
.from(orgs)
.where(eq(orgs.orgId, orgId))
.limit(1);
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
await db.transaction(async (trx) => {
await trx
.delete(userOrgs)
.where(
and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId))
);
await db.delete(userResources).where(
and(
eq(userResources.userId, userId),
exists(
db
.select()
.from(resources)
.where(
and(
eq(
resources.resourceId,
userResources.resourceId
),
eq(resources.orgId, orgId)
)
)
)
)
);
await db.delete(userSites).where(
and(
eq(userSites.userId, userId),
exists(
db
.select()
.from(sites)
.where(
and(
eq(sites.siteId, userSites.siteId),
eq(sites.orgId, orgId)
)
)
)
)
);
userCount = await trx
.select()
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
await removeUserFromOrg(org, userId, trx);
// if (build === "saas") {
// const [rootUser] = await trx
@@ -139,14 +112,6 @@ export async function removeUserOrg(
await calculateUserClientsForOrgs(userId, trx);
});
if (userCount) {
await usageService.updateCount(
orgId,
FeatureId.USERS,
userCount.length
);
}
return response(res, {
data: null,
success: true,

View File

@@ -6,6 +6,7 @@ import { redirect } from "next/navigation";
import { getTranslations } from "next-intl/server";
import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser";
import { getCachedOrg } from "@app/lib/api/getCachedOrg";
import { build } from "@server/build";
type BillingSettingsProps = {
children: React.ReactNode;
@@ -17,6 +18,9 @@ export default async function BillingSettingsPage({
params
}: BillingSettingsProps) {
const { orgId } = await params;
if (build !== "saas") {
redirect(`/${orgId}/settings`);
}
const user = await verifySession();
@@ -40,6 +44,10 @@ export default async function BillingSettingsPage({
redirect(`/${orgId}`);
}
if (!(org?.org?.isBillingOrg && orgUser?.isOwner)) {
redirect(`/${orgId}`);
}
const t = await getTranslations();
return (

View File

@@ -110,37 +110,42 @@ const planOptions: PlanOption[] = [
// Tier limits mapping derived from limit sets
const tierLimits: Record<
Tier | "basic",
{ users: number; sites: number; domains: number; remoteNodes: number }
{ users: number; sites: number; domains: number; remoteNodes: number; organizations: number }
> = {
basic: {
users: freeLimitSet[FeatureId.USERS]?.value ?? 0,
sites: freeLimitSet[FeatureId.SITES]?.value ?? 0,
domains: freeLimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: freeLimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: freeLimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
tier1: {
users: tier1LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier1LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier1LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: tier1LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier1LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
tier2: {
users: tier2LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier2LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier2LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: tier2LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier2LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
tier3: {
users: tier3LimitSet[FeatureId.USERS]?.value ?? 0,
sites: tier3LimitSet[FeatureId.SITES]?.value ?? 0,
domains: tier3LimitSet[FeatureId.DOMAINS]?.value ?? 0,
remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0
remoteNodes: tier3LimitSet[FeatureId.REMOTE_EXIT_NODES]?.value ?? 0,
organizations: tier3LimitSet[FeatureId.ORGINIZATIONS]?.value ?? 0
},
enterprise: {
users: 0, // Custom for enterprise
sites: 0, // Custom for enterprise
domains: 0, // Custom for enterprise
remoteNodes: 0 // Custom for enterprise
remoteNodes: 0, // Custom for enterprise
organizations: 0 // Custom for enterprise
}
};
@@ -179,6 +184,7 @@ export default function BillingPage() {
const SITES = "sites";
const DOMAINS = "domains";
const REMOTE_EXIT_NODES = "remoteExitNodes";
const ORGINIZATIONS = "organizations";
// Confirmation dialog state
const [showConfirmDialog, setShowConfirmDialog] = useState(false);
@@ -619,6 +625,16 @@ export default function BillingPage() {
});
}
// Check organizations
const organizationsUsage = getUsageValue(ORGINIZATIONS);
if (limits.organizations > 0 && organizationsUsage > limits.organizations) {
violations.push({
feature: "Organizations",
currentUsage: organizationsUsage,
newLimit: limits.organizations
});
}
return violations;
};
@@ -752,7 +768,7 @@ export default function BillingPage() {
<div className="text-sm text-muted-foreground mb-3">
{t("billingMaximumLimits") || "Maximum Limits"}
</div>
<InfoSections cols={4}>
<InfoSections cols={5}>
<InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingUsers") || "Users"}
@@ -855,6 +871,41 @@ export default function BillingPage() {
)}
</InfoSectionContent>
</InfoSection>
<InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingOrganizations") ||
"Organizations"}
</InfoSectionTitle>
<InfoSectionContent className="text-sm">
{isOverLimit(ORGINIZATIONS) ? (
<Tooltip>
<TooltipTrigger className="flex items-center gap-1">
<AlertTriangle className="h-3 w-3 text-orange-400" />
<span className={cn(
"text-orange-600 dark:text-orange-400 font-medium"
)}>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "orgs"}
</span>
</TooltipTrigger>
<TooltipContent>
<p>{t("billingUsageExceedsLimit", { current: getUsageValue(ORGINIZATIONS), limit: getLimitValue(ORGINIZATIONS) ?? 0 }) || `Current usage (${getUsageValue(ORGINIZATIONS)}) exceeds limit (${getLimitValue(ORGINIZATIONS)})`}</p>
</TooltipContent>
</Tooltip>
) : (
<>
{getLimitValue(ORGINIZATIONS) ??
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(ORGINIZATIONS) !==
null && "orgs"}
</>
)}
</InfoSectionContent>
</InfoSection>
<InfoSection>
<InfoSectionTitle className="flex items-center gap-1 text-xs">
{t("billingRemoteNodes") ||
@@ -872,7 +923,7 @@ export default function BillingPage() {
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(REMOTE_EXIT_NODES) !==
null && "remote nodes"}
null && "nodes"}
</span>
</TooltipTrigger>
<TooltipContent>
@@ -885,7 +936,7 @@ export default function BillingPage() {
t("billingUnlimited") ??
"∞"}{" "}
{getLimitValue(REMOTE_EXIT_NODES) !==
null && "remote nodes"}
null && "nodes"}
</>
)}
</InfoSectionContent>
@@ -1016,6 +1067,17 @@ export default function BillingPage() {
"Domains"}
</span>
</div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span>
{
tierLimits[pendingTier.tier]
.organizations
}{" "}
{t("billingOrganizations") ||
"Organizations"}
</span>
</div>
<div className="flex items-center gap-2">
<Check className="h-4 w-4 text-green-600" />
<span>

View File

@@ -4,6 +4,8 @@ import { redirect } from "next/navigation";
import { cache } from "react";
import { getTranslations } from "next-intl/server";
import { build } from "@server/build";
import { getCachedOrgUser } from "@app/lib/api/getCachedOrgUser";
import { getCachedOrg } from "@app/lib/api/getCachedOrg";
type LicensesSettingsProps = {
children: React.ReactNode;
@@ -27,6 +29,26 @@ export default async function LicensesSetingsLayoutProps({
redirect(`/`);
}
let orgUser = null;
try {
const res = await getCachedOrgUser(orgId, user.userId);
orgUser = res.data.data;
} catch {
redirect(`/${orgId}`);
}
let org = null;
try {
const res = await getCachedOrg(orgId);
org = res.data.data;
} catch {
redirect(`/${orgId}`);
}
if (!org?.org?.isBillingOrg || !orgUser?.isOwner) {
redirect(`/${orgId}`);
}
const t = await getTranslations();
return (

View File

@@ -3,11 +3,7 @@ import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog";
import { Button } from "@app/components/ui/button";
import { useOrgContext } from "@app/hooks/useOrgContext";
import { toast } from "@app/hooks/useToast";
import {
useState,
useTransition,
useActionState
} from "react";
import { useState, useTransition, useActionState } from "react";
import {
Form,
FormControl,
@@ -54,7 +50,7 @@ export default function GeneralPage() {
return (
<SettingsContainer>
<GeneralSectionForm org={org.org} />
{build !== "saas" && <DeleteForm org={org.org} />}
{!org.org.isBillingOrg && <DeleteForm org={org.org} />}
</SettingsContainer>
);
}

View File

@@ -77,12 +77,16 @@ export default async function SettingsLayout(props: SettingsLayoutProps) {
}
} catch (e) {}
const primaryOrg = orgs.find((o) => o.orgId === params.orgId)?.isPrimaryOrg;
return (
<UserProvider user={user}>
<Layout
orgId={params.orgId}
orgs={orgs}
navItems={orgNavSections(env)}
navItems={orgNavSections(env, {
isPrimaryOrg: primaryOrg
})}
>
{children}
</Layout>

View File

@@ -15,6 +15,7 @@ export default async function Page(props: {
redirect: string | undefined;
email: string | undefined;
fromSmartLogin: string | undefined;
skipVerificationEmail: string | undefined;
}>;
}) {
const searchParams = await props.searchParams;
@@ -75,6 +76,10 @@ export default async function Page(props: {
inviteId={inviteId}
emailParam={searchParams.email}
fromSmartLogin={searchParams.fromSmartLogin === "true"}
skipVerificationEmail={
searchParams.skipVerificationEmail === "true" ||
searchParams.skipVerificationEmail === "1"
}
/>
<p className="text-center text-muted-foreground mt-4">

View File

@@ -31,6 +31,10 @@ export type SidebarNavSection = {
items: SidebarNavItem[];
};
export type OrgNavSectionsOptions = {
isPrimaryOrg?: boolean;
};
// Merged from 'user-management-and-resources' branch
export const orgLangingNavItems: SidebarNavItem[] = [
{
@@ -40,7 +44,10 @@ export const orgLangingNavItems: SidebarNavItem[] = [
}
];
export const orgNavSections = (env?: Env): SidebarNavSection[] => [
export const orgNavSections = (
env?: Env,
options?: OrgNavSectionsOptions
): SidebarNavSection[] => [
{
heading: "sidebarGeneral",
items: [
@@ -214,28 +221,28 @@ export const orgNavSections = (env?: Env): SidebarNavSection[] => [
title: "sidebarSettings",
href: "/{orgId}/settings/general",
icon: <Settings className="size-4 flex-none" />
},
...(build == "saas"
? [
}
]
},
...(build == "saas" && options?.isPrimaryOrg
? [
{
heading: "sidebarBillingAndLicenses",
items: [
{
title: "sidebarBilling",
href: "/{orgId}/settings/billing",
icon: <CreditCard className="size-4 flex-none" />
}
]
: []),
...(build == "saas"
? [
},
{
title: "sidebarEnterpriseLicenses",
href: "/{orgId}/settings/license",
icon: <TicketCheck className="size-4 flex-none" />
}
]
: [])
]
}
}
]
: [])
];
export const adminNavSections = (env?: Env): SidebarNavSection[] => [

View File

@@ -73,7 +73,7 @@ export default async function Page(props: {
if (!orgs.length) {
if (!env.flags.disableUserCreateOrg || user.serverAdmin) {
redirect("/setup");
redirect("/setup?firstOrg");
}
}
@@ -86,6 +86,14 @@ export default async function Page(props: {
targetOrgId = lastOrgCookie;
} else {
let ownedOrg = orgs.find((org) => org.isOwner);
let primaryOrg = orgs.find((org) => org.isPrimaryOrg);
if (!ownedOrg) {
if (primaryOrg) {
ownedOrg = primaryOrg;
} else {
ownedOrg = orgs[0];
}
}
if (!ownedOrg) {
ownedOrg = orgs[0];
}

View File

@@ -4,19 +4,14 @@ import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { toast } from "@app/hooks/useToast";
import { useCallback, useEffect, useState } from "react";
import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle
} from "@app/components/ui/card";
import { formatAxiosError } from "@app/lib/api";
import { createApiClient } from "@app/lib/api";
import { useEnvContext } from "@app/hooks/useEnvContext";
import { useUserContext } from "@app/hooks/useUserContext";
import { build } from "@server/build";
import { Separator } from "@/components/ui/separator";
import { z } from "zod";
import { useRouter } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import { useForm } from "react-hook-form";
import { zodResolver } from "@hookform/resolvers/zod";
import {
@@ -35,7 +30,7 @@ import {
CollapsibleContent,
CollapsibleTrigger
} from "@app/components/ui/collapsible";
import { ChevronsUpDown } from "lucide-react";
import { ArrowRight, ChevronsUpDown } from "lucide-react";
import { cn } from "@app/lib/cn";
type Step = "org" | "site" | "resources";
@@ -45,6 +40,7 @@ export default function StepperForm() {
const [orgIdTaken, setOrgIdTaken] = useState(false);
const t = useTranslations();
const { env } = useEnvContext();
const { user } = useUserContext();
const [loading, setLoading] = useState(false);
const [isChecked, setIsChecked] = useState(false);
@@ -54,7 +50,10 @@ export default function StepperForm() {
const orgSchema = z.object({
orgName: z.string().min(1, { message: t("orgNameRequired") }),
orgId: z.string().min(1, { message: t("orgIdRequired") }),
orgId: z
.string()
.min(1, { message: t("orgIdRequired") })
.max(32, { message: t("orgIdMaxLength") }),
subnet: z.string().min(1, { message: t("subnetRequired") }),
utilitySubnet: z.string().min(1, { message: t("subnetRequired") })
});
@@ -71,12 +70,27 @@ export default function StepperForm() {
const api = createApiClient(useEnvContext());
const router = useRouter();
const searchParams = useSearchParams();
const isFirstOrg = searchParams.get("firstOrg") != null;
// Fetch default subnet on component mount
useEffect(() => {
fetchDefaultSubnet();
}, []);
// Prefill org name and id when build is saas and firstOrg query param is set
useEffect(() => {
if (build !== "saas" || !user || !isFirstOrg) return;
const orgName = user.email
? `${user.email}'s Organization`
: "My Organization";
const orgId = `org_${user.userId}`;
orgForm.setValue("orgName", orgName);
orgForm.setValue("orgId", orgId);
debouncedCheckOrgIdAvailability(orgId);
}, []);
const fetchDefaultSubnet = async () => {
try {
const res = await api.get(`/pick-org-defaults`);
@@ -129,6 +143,16 @@ export default function StepperForm() {
.replace(/^-+|-+$/g, "");
};
const sanitizeOrgId = (value: string) => {
return value
.toLowerCase()
.replace(/\s+/g, "-")
.replace(/[^a-z0-9_-]/g, "")
.replace(/-+/g, "-")
.replace(/^-+|-+$/g, "")
.slice(0, 32);
};
async function orgSubmit(values: z.infer<typeof orgSchema>) {
if (orgIdTaken) {
return;
@@ -161,263 +185,254 @@ export default function StepperForm() {
}
return (
<>
<Card>
<CardHeader>
<CardTitle>{t("setupNewOrg")}</CardTitle>
<CardDescription>{t("setupCreate")}</CardDescription>
</CardHeader>
<CardContent>
<section className="space-y-6">
<div className="flex justify-between mb-2">
<div className="flex flex-col items-center">
<div
className={`w-8 h-8 rounded-full flex items-center justify-center mb-2 ${
currentStep === "org"
? "bg-primary text-primary-foreground"
: "bg-muted text-muted-foreground"
}`}
>
1
</div>
<span
className={`text-sm font-medium ${
currentStep === "org"
? "text-primary"
: "text-muted-foreground"
}`}
>
{t("setupCreateOrg")}
</span>
</div>
<div className="flex flex-col items-center">
<div
className={`w-8 h-8 rounded-full flex items-center justify-center mb-2 ${
currentStep === "site"
? "bg-primary text-primary-foreground"
: "bg-muted text-muted-foreground"
}`}
>
2
</div>
<span
className={`text-sm font-medium ${
currentStep === "site"
? "text-primary"
: "text-muted-foreground"
}`}
>
{t("siteCreate")}
</span>
</div>
<div className="flex flex-col items-center">
<div
className={`w-8 h-8 rounded-full flex items-center justify-center mb-2 ${
currentStep === "resources"
? "bg-primary text-primary-foreground"
: "bg-muted text-muted-foreground"
}`}
>
3
</div>
<span
className={`text-sm font-medium ${
currentStep === "resources"
? "text-primary"
: "text-muted-foreground"
}`}
>
{t("setupCreateResources")}
</span>
</div>
</div>
<section className="space-y-6">
<div>
<h1 className="text-2xl font-semibold tracking-tight">
{t("setupNewOrg")}
</h1>
<p className="text-muted-foreground text-sm mt-1">
{t("setupCreate")}
</p>
</div>
<div className="flex justify-between mb-2">
<div className="flex flex-col items-center">
<div
className={`w-8 h-8 rounded-full flex items-center justify-center mb-2 ${
currentStep === "org"
? "bg-primary text-primary-foreground"
: "bg-muted text-muted-foreground"
}`}
>
1
</div>
<span
className={`text-sm font-medium ${
currentStep === "org"
? "text-primary"
: "text-muted-foreground"
}`}
>
{t("setupCreateOrg")}
</span>
</div>
<div className="flex flex-col items-center">
<div
className={`w-8 h-8 rounded-full flex items-center justify-center mb-2 ${
currentStep === "site"
? "bg-primary text-primary-foreground"
: "bg-muted text-muted-foreground"
}`}
>
2
</div>
<span
className={`text-sm font-medium ${
currentStep === "site"
? "text-primary"
: "text-muted-foreground"
}`}
>
{t("siteCreate")}
</span>
</div>
<div className="flex flex-col items-center">
<div
className={`w-8 h-8 rounded-full flex items-center justify-center mb-2 ${
currentStep === "resources"
? "bg-primary text-primary-foreground"
: "bg-muted text-muted-foreground"
}`}
>
3
</div>
<span
className={`text-sm font-medium ${
currentStep === "resources"
? "text-primary"
: "text-muted-foreground"
}`}
>
{t("setupCreateResources")}
</span>
</div>
</div>
<Separator />
<Separator />
{currentStep === "org" && (
<Form {...orgForm}>
<form
onSubmit={orgForm.handleSubmit(orgSubmit)}
className="space-y-4"
>
<FormField
control={orgForm.control}
name="orgName"
render={({ field }) => (
<FormItem>
<FormLabel>
{t("setupOrgName")}
</FormLabel>
<FormControl>
<Input
type="text"
{...field}
onChange={(e) => {
// Prevent "/" in orgName input
const sanitizedValue =
e.target.value.replace(
/\//g,
"-"
);
const orgId =
generateId(
sanitizedValue
);
orgForm.setValue(
"orgId",
orgId
);
orgForm.setValue(
"orgName",
sanitizedValue
);
debouncedCheckOrgIdAvailability(
orgId
);
}}
value={field.value.replace(
/\//g,
"-"
)}
/>
</FormControl>
<FormMessage />
<FormDescription>
{t("orgDisplayName")}
</FormDescription>
</FormItem>
)}
/>
<FormField
control={orgForm.control}
name="orgId"
render={({ field }) => (
<FormItem>
<FormLabel>
{t("orgId")}
</FormLabel>
<FormControl>
<Input
type="text"
{...field}
/>
</FormControl>
<FormMessage />
<FormDescription>
{t(
"setupIdentifierMessage"
)}
</FormDescription>
</FormItem>
)}
/>
{currentStep === "org" && (
<Form {...orgForm}>
<form
onSubmit={orgForm.handleSubmit(orgSubmit)}
className="space-y-4"
>
<FormField
control={orgForm.control}
name="orgName"
render={({ field }) => (
<FormItem>
<FormLabel>{t("setupOrgName")}</FormLabel>
<FormControl>
<Input
type="text"
{...field}
onChange={(e) => {
// Prevent "/" in orgName input
const sanitizedValue =
e.target.value.replace(
/\//g,
"-"
);
const orgId =
generateId(sanitizedValue);
orgForm.setValue(
"orgId",
orgId
);
orgForm.setValue(
"orgName",
sanitizedValue
);
debouncedCheckOrgIdAvailability(
orgId
);
}}
value={field.value.replace(
/\//g,
"-"
)}
/>
</FormControl>
<FormMessage />
<FormDescription>
{t("orgDisplayName")}
</FormDescription>
</FormItem>
)}
/>
<FormField
control={orgForm.control}
name="orgId"
render={({ field }) => (
<FormItem>
<FormLabel>{t("orgId")}</FormLabel>
<FormControl>
<Input
type="text"
{...field}
onChange={(e) => {
const value = sanitizeOrgId(
e.target.value
);
field.onChange(value);
setOrgIdTaken(false);
if (value) {
debouncedCheckOrgIdAvailability(
value
);
}
}}
/>
</FormControl>
<FormMessage />
<FormDescription>
{t("setupIdentifierMessage")}
</FormDescription>
</FormItem>
)}
/>
<Collapsible
open={isAdvancedOpen}
onOpenChange={setIsAdvancedOpen}
className="space-y-2"
<Collapsible
open={isAdvancedOpen}
onOpenChange={setIsAdvancedOpen}
className="space-y-2"
>
<div className="flex items-center justify-between space-x-4">
<CollapsibleTrigger asChild>
<Button
type="button"
variant="text"
size="sm"
className="p-0 flex items-center justify-between w-full"
>
<div className="flex items-center justify-between space-x-4">
<CollapsibleTrigger asChild>
<Button
type="button"
variant="text"
size="sm"
className="p-0 flex items-center justify-between w-full"
>
<h4 className="text-sm">
{t("advancedSettings")}
</h4>
<div>
<ChevronsUpDown className="h-4 w-4" />
<span className="sr-only">
{t("toggle")}
</span>
</div>
</Button>
</CollapsibleTrigger>
<h4 className="text-sm">
{t("advancedSettings")}
</h4>
<div>
<ChevronsUpDown className="h-4 w-4" />
<span className="sr-only">
{t("toggle")}
</span>
</div>
<CollapsibleContent className="space-y-4">
<FormField
control={orgForm.control}
name="subnet"
render={({ field }) => (
<FormItem>
<FormLabel>
{t(
"setupSubnetAdvanced"
)}
</FormLabel>
<FormControl>
<Input
type="text"
{...field}
/>
</FormControl>
<FormMessage />
<FormDescription>
{t(
"setupSubnetDescription"
)}
</FormDescription>
</FormItem>
</Button>
</CollapsibleTrigger>
</div>
<CollapsibleContent className="space-y-4">
<FormField
control={orgForm.control}
name="subnet"
render={({ field }) => (
<FormItem>
<FormLabel>
{t("setupSubnetAdvanced")}
</FormLabel>
<FormControl>
<Input type="text" {...field} />
</FormControl>
<FormMessage />
<FormDescription>
{t("setupSubnetDescription")}
</FormDescription>
</FormItem>
)}
/>
<FormField
control={orgForm.control}
name="utilitySubnet"
render={({ field }) => (
<FormItem>
<FormLabel>
{t("setupUtilitySubnet")}
</FormLabel>
<FormControl>
<Input type="text" {...field} />
</FormControl>
<FormMessage />
<FormDescription>
{t(
"setupUtilitySubnetDescription"
)}
/>
</FormDescription>
</FormItem>
)}
/>
</CollapsibleContent>
</Collapsible>
<FormField
control={orgForm.control}
name="utilitySubnet"
render={({ field }) => (
<FormItem>
<FormLabel>
{t(
"setupUtilitySubnet"
)}
</FormLabel>
<FormControl>
<Input
type="text"
{...field}
/>
</FormControl>
<FormMessage />
<FormDescription>
{t(
"setupUtilitySubnetDescription"
)}
</FormDescription>
</FormItem>
)}
/>
</CollapsibleContent>
</Collapsible>
{orgIdTaken && !orgCreated ? (
<Alert variant="destructive">
<AlertDescription>
{t("setupErrorIdentifier")}
</AlertDescription>
</Alert>
) : null}
{orgIdTaken && !orgCreated ? (
<Alert variant="destructive">
<AlertDescription>
{t("setupErrorIdentifier")}
</AlertDescription>
</Alert>
) : null}
{/* Error Alert removed, errors now shown as toast */}
{/* Error Alert removed, errors now shown as toast */}
<div className="flex justify-end">
<Button
type="submit"
loading={loading}
disabled={loading || orgIdTaken}
>
{t("setupCreateOrg")}
</Button>
</div>
</form>
</Form>
)}
</section>
</CardContent>
</Card>
</>
<div className="flex justify-end">
<Button
type="submit"
loading={loading}
disabled={loading || orgIdTaken}
>
{t("setupCreateOrg")}
<ArrowRight className="ml-2 h-4 w-4" />
</Button>
</div>
</form>
</Form>
)}
</section>
);
}

View File

@@ -189,10 +189,12 @@ export function LayoutSidebar({
<div className="w-full border-t border-border" />
<div className="p-4 pt-1 flex flex-col shrink-0">
{canShowProductUpdates && (
{canShowProductUpdates ? (
<div className="mb-3">
<ProductUpdates isCollapsed={isSidebarCollapsed} />
</div>
) : (
<div className="mb-3"></div>
)}
{build === "enterprise" && (

View File

@@ -58,6 +58,18 @@ type Resource = {
siteName?: string | null;
};
type SiteResource = {
siteResourceId: number;
name: string;
destination: string;
mode: string;
protocol: string | null;
enabled: boolean;
alias: string | null;
aliasAddress: string | null;
type: 'site';
};
type MemberResourcesPortalProps = {
orgId: string;
};
@@ -334,7 +346,9 @@ export default function MemberResourcesPortal({
const { toast } = useToast();
const [resources, setResources] = useState<Resource[]>([]);
const [siteResources, setSiteResources] = useState<SiteResource[]>([]);
const [filteredResources, setFilteredResources] = useState<Resource[]>([]);
const [filteredSiteResources, setFilteredSiteResources] = useState<SiteResource[]>([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [searchQuery, setSearchQuery] = useState("");
@@ -360,7 +374,9 @@ export default function MemberResourcesPortal({
if (response.data.success) {
setResources(response.data.data.resources);
setSiteResources(response.data.data.siteResources || []);
setFilteredResources(response.data.data.resources);
setFilteredSiteResources(response.data.data.siteResources || []);
} else {
setError("Failed to load resources");
}
@@ -417,17 +433,61 @@ export default function MemberResourcesPortal({
setFilteredResources(filtered);
// Filter and sort site resources
const filteredSites = siteResources.filter(
(resource) =>
resource.name
.toLowerCase()
.includes(searchQuery.toLowerCase()) ||
resource.destination
.toLowerCase()
.includes(searchQuery.toLowerCase())
);
// Sort site resources
filteredSites.sort((a, b) => {
switch (sortBy) {
case "name-asc":
return a.name.localeCompare(b.name);
case "name-desc":
return b.name.localeCompare(a.name);
case "domain-asc":
case "domain-desc":
// Sort by destination for site resources
const destCompare = sortBy === "domain-asc"
? a.destination.localeCompare(b.destination)
: b.destination.localeCompare(a.destination);
return destCompare;
case "status-enabled":
return b.enabled ? 1 : -1;
case "status-disabled":
return a.enabled ? 1 : -1;
default:
return a.name.localeCompare(b.name);
}
});
setFilteredSiteResources(filteredSites);
// Reset to first page when search/sort changes
setCurrentPage(1);
}, [resources, searchQuery, sortBy]);
}, [resources, siteResources, searchQuery, sortBy]);
// Calculate pagination
const totalPages = Math.ceil(filteredResources.length / itemsPerPage);
const totalItems = filteredResources.length + filteredSiteResources.length;
const totalPages = Math.ceil(totalItems / itemsPerPage);
const startIndex = (currentPage - 1) * itemsPerPage;
const paginatedResources = filteredResources.slice(
startIndex,
startIndex + itemsPerPage
);
const remainingSlots = itemsPerPage - paginatedResources.length;
const paginatedSiteResources = remainingSlots > 0
? filteredSiteResources.slice(
Math.max(0, startIndex - filteredResources.length),
Math.max(0, startIndex - filteredResources.length) + remainingSlots
)
: [];
const handleOpenResource = (resource: Resource) => {
// Open the resource in a new tab
@@ -575,7 +635,7 @@ export default function MemberResourcesPortal({
</div>
{/* Resources Content */}
{filteredResources.length === 0 ? (
{filteredResources.length === 0 && filteredSiteResources.length === 0 ? (
/* Enhanced Empty State */
<Card>
<CardContent className="flex flex-col items-center justify-center py-20 text-center">
@@ -623,9 +683,20 @@ export default function MemberResourcesPortal({
</Card>
) : (
<>
{/* Resources Grid */}
<div className="grid gap-5 grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-4 auto-cols-fr">
{paginatedResources.map((resource) => (
{/* Public Resources Section */}
{paginatedResources.length > 0 && (
<>
<div className="mb-4">
<h3 className="text-lg font-semibold text-foreground flex items-center gap-2">
<Globe className="h-5 w-5" />
Public Resources
</h3>
<p className="text-sm text-muted-foreground mt-1">
Web applications and services accessible via browser
</p>
</div>
<div className="grid gap-5 grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-4 auto-cols-fr mb-8">
{paginatedResources.map((resource) => (
<Card key={resource.resourceId}>
<div className="p-6">
<div className="flex items-center justify-between gap-3">
@@ -702,13 +773,167 @@ export default function MemberResourcesPortal({
</Card>
))}
</div>
</>
)}
{/* Private Resources (Site Resources) Section */}
{paginatedSiteResources.length > 0 && (
<>
<div className="mb-4">
<h3 className="text-lg font-semibold text-foreground flex items-center gap-2">
<Combine className="h-5 w-5" />
Private Resources
</h3>
<p className="text-sm text-muted-foreground mt-1">
Internal network resources accessible via client
</p>
</div>
<div className="grid gap-5 grid-cols-1 md:grid-cols-2 lg:grid-cols-3 xl:grid-cols-3 2xl:grid-cols-4 auto-cols-fr mb-8">
{paginatedSiteResources.map((siteResource) => (
<Card key={siteResource.siteResourceId}>
<div className="p-6">
<div className="flex items-center justify-between gap-3">
<div className="flex items-center min-w-0 flex-1 gap-3 overflow-hidden">
<TooltipProvider>
<Tooltip>
<TooltipTrigger className="min-w-0 max-w-full">
<CardTitle className="text-lg font-bold text-foreground truncate group-hover:text-primary transition-colors">
{siteResource.name}
</CardTitle>
</TooltipTrigger>
<TooltipContent>
<p className="max-w-xs break-words">
{siteResource.name}
</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
<div className="flex-shrink-0">
<InfoPopup>
<div className="space-y-2 text-sm">
<div className="text-xs font-medium mb-1.5">Resource Details</div>
<div>
<span className="font-medium">Mode:</span>
<span className="ml-2 text-muted-foreground capitalize">
{siteResource.mode}
</span>
</div>
{siteResource.protocol && (
<div>
<span className="font-medium">Protocol:</span>
<span className="ml-2 text-muted-foreground uppercase">
{siteResource.protocol}
</span>
</div>
)}
{siteResource.alias && (
<div>
<span className="font-medium">Alias:</span>
<span className="ml-2 text-muted-foreground">
{siteResource.alias}
</span>
</div>
)}
{siteResource.aliasAddress && (
<div>
<span className="font-medium">Alias Address:</span>
<span className="ml-2 text-muted-foreground">
{siteResource.aliasAddress}
</span>
</div>
)}
<div>
<span className="font-medium">Status:</span>
<span className={`ml-2 ${siteResource.enabled ? 'text-green-600' : 'text-red-600'}`}>
{siteResource.enabled ? 'Enabled' : 'Disabled'}
</span>
</div>
</div>
</InfoPopup>
</div>
</div>
<div className="mt-3">
{siteResource.alias ? (
<>
{/* Alias as primary */}
<div className="flex items-center gap-2 mb-1">
<div className="text-base font-semibold text-foreground text-left truncate flex-1">
{siteResource.alias}
</div>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 text-muted-foreground"
onClick={() => {
navigator.clipboard.writeText(
siteResource.alias!
);
toast({
title: "Copied to clipboard",
description:
"Resource alias has been copied to your clipboard.",
duration: 2000
});
}}
>
<Copy className="h-4 w-4" />
</Button>
</div>
{/* Destination as secondary */}
<div className="text-xs text-muted-foreground truncate">
{siteResource.destination}
</div>
</>
) : (
/* Destination as primary when no alias */
<div className="flex items-center gap-2">
<div className="text-sm text-muted-foreground font-medium text-left truncate flex-1">
{siteResource.destination}
</div>
<Button
variant="ghost"
size="icon"
className="h-8 w-8 text-muted-foreground"
onClick={() => {
navigator.clipboard.writeText(
siteResource.destination
);
toast({
title: "Copied to clipboard",
description:
"Resource destination has been copied to your clipboard.",
duration: 2000
});
}}
>
<Copy className="h-4 w-4" />
</Button>
</div>
)}
</div>
</div>
<div className="p-6 pt-0 mt-auto">
<div className="flex items-center justify-center py-2 px-4 bg-muted/50 rounded text-sm text-muted-foreground">
<Combine className="h-3.5 w-3.5 mr-2" />
Requires Client Connection
</div>
</div>
</Card>
))}
</div>
</>
)}
{/* Pagination Controls */}
<PaginationControls
currentPage={currentPage}
totalPages={totalPages}
onPageChange={handlePageChange}
totalItems={filteredResources.length}
totalItems={totalItems}
itemsPerPage={itemsPerPage}
/>
</>

View File

@@ -20,12 +20,13 @@ import {
TooltipProvider,
TooltipTrigger
} from "@app/components/ui/tooltip";
import { Badge } from "@app/components/ui/badge";
import { useEnvContext } from "@app/hooks/useEnvContext";
import { cn } from "@app/lib/cn";
import { ListUserOrgsResponse } from "@server/routers/org";
import { Check, ChevronsUpDown, Plus, Building2, Users } from "lucide-react";
import { useRouter } from "next/navigation";
import { useState } from "react";
import { usePathname, useRouter } from "next/navigation";
import { useMemo, useState } from "react";
import { useUserContext } from "@app/hooks/useUserContext";
import { useTranslations } from "next-intl";
@@ -43,11 +44,23 @@ export function OrgSelector({
const { user } = useUserContext();
const [open, setOpen] = useState(false);
const router = useRouter();
const pathname = usePathname();
const { env } = useEnvContext();
const t = useTranslations();
const selectedOrg = orgs?.find((org) => org.orgId === orgId);
const sortedOrgs = useMemo(() => {
if (!orgs?.length) return orgs ?? [];
return [...orgs].sort((a, b) => {
const aPrimary = Boolean(a.isPrimaryOrg);
const bPrimary = Boolean(b.isPrimaryOrg);
if (aPrimary && !bPrimary) return -1;
if (!aPrimary && bPrimary) return 1;
return 0;
});
}, [orgs]);
const orgSelectorContent = (
<Popover open={open} onOpenChange={setOpen}>
<PopoverTrigger asChild>
@@ -124,25 +137,39 @@ export function OrgSelector({
)}
<CommandGroup heading={t("orgs")} className="py-2">
<CommandList>
{orgs?.map((org) => (
{sortedOrgs.map((org) => (
<CommandItem
key={org.orgId}
onSelect={() => {
setOpen(false);
router.push(`/${org.orgId}/settings`);
const newPath = pathname.replace(
/^\/[^/]+/,
`/${org.orgId}`
);
router.push(newPath);
}}
className="mx-2 rounded-md"
>
<div className="flex items-center justify-center w-8 h-8 rounded-lg bg-muted mr-3">
<Users className="h-4 w-4 text-muted-foreground" />
</div>
<div className="flex flex-col flex-1">
<span className="font-medium">
<div className="flex flex-col flex-1 min-w-0">
<span className="font-medium truncate">
{org.name}
</span>
<span className="text-xs text-muted-foreground">
{t("organization")}
</span>
<div className="flex items-center gap-2 min-w-0">
<span className="text-xs text-muted-foreground font-mono truncate">
{org.orgId}
</span>
{org.isPrimaryOrg && (
<Badge
variant="outline"
className="shrink-0 text-[10px] px-1.5 py-0 font-medium ml-auto"
>
{t("primary")}
</Badge>
)}
</div>
</div>
<Check
className={cn(

View File

@@ -72,6 +72,7 @@ type SignupFormProps = {
inviteToken?: string;
emailParam?: string;
fromSmartLogin?: boolean;
skipVerificationEmail?: boolean;
};
const formSchema = z
@@ -103,7 +104,8 @@ export default function SignupForm({
inviteId,
inviteToken,
emailParam,
fromSmartLogin = false
fromSmartLogin = false,
skipVerificationEmail = false
}: SignupFormProps) {
const router = useRouter();
const { env } = useEnvContext();
@@ -147,7 +149,8 @@ export default function SignupForm({
inviteToken,
termsAcceptedTimestamp: termsAgreedAt,
marketingEmailConsent:
build === "saas" ? marketingEmailConsent : undefined
build === "saas" ? marketingEmailConsent : undefined,
skipVerificationEmail: skipVerificationEmail || undefined
})
.catch((e) => {
console.error(e);