diff --git a/server/auth/actions.ts b/server/auth/actions.ts index 450e3f42b..6cdc4fa0a 100644 --- a/server/auth/actions.ts +++ b/server/auth/actions.ts @@ -109,6 +109,9 @@ export enum ActionsEnum { listApiKeyActions = "listApiKeyActions", listApiKeys = "listApiKeys", getApiKey = "getApiKey", + createSiteProvisioningKey = "createSiteProvisioningKey", + listSiteProvisioningKeys = "listSiteProvisioningKeys", + deleteSiteProvisioningKey = "deleteSiteProvisioningKey", getCertificate = "getCertificate", restartCertificate = "restartCertificate", billing = "billing", diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index c9d7cc907..63d6bcd60 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -7,7 +7,8 @@ import { bigint, real, text, - index + index, + primaryKey } from "drizzle-orm/pg-core"; import { InferSelectModel } from "drizzle-orm"; import { @@ -89,7 +90,9 @@ export const subscriptions = pgTable("subscriptions", { export const subscriptionItems = pgTable("subscriptionItems", { subscriptionItemId: serial("subscriptionItemId").primaryKey(), - stripeSubscriptionItemId: varchar("stripeSubscriptionItemId", { length: 255 }), + stripeSubscriptionItemId: varchar("stripeSubscriptionItemId", { + length: 255 + }), subscriptionId: varchar("subscriptionId", { length: 255 }) .notNull() .references(() => subscriptions.subscriptionId, { @@ -329,13 +332,44 @@ export const approvals = pgTable("approvals", { }); export const bannedEmails = pgTable("bannedEmails", { - email: varchar("email", { length: 255 }).primaryKey(), + email: varchar("email", { length: 255 }).primaryKey() }); export const bannedIps = pgTable("bannedIps", { - ip: varchar("ip", { length: 255 }).primaryKey(), + ip: varchar("ip", { length: 255 }).primaryKey() }); +export const siteProvisioningKeys = pgTable("siteProvisioningKeys", { + siteProvisioningKeyId: varchar("siteProvisioningKeyId", { + length: 255 + }).primaryKey(), + name: varchar("name", { length: 255 }).notNull(), + siteProvisioningKeyHash: text("siteProvisioningKeyHash").notNull(), + lastChars: varchar("lastChars", { length: 4 }).notNull(), + createdAt: varchar("dateCreated", { length: 255 }).notNull() +}); + +export const siteProvisioningKeyOrg = pgTable( + "siteProvisioningKeyOrg", + { + siteProvisioningKeyId: varchar("siteProvisioningKeyId", { + length: 255 + }) + .notNull() + .references(() => siteProvisioningKeys.siteProvisioningKeyId, { + onDelete: "cascade" + }), + orgId: varchar("orgId", { length: 255 }) + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }) + }, + (table) => [ + primaryKey({ + columns: [table.siteProvisioningKeyId, table.orgId] + }) + ] +); + export type Approval = InferSelectModel; export type Limit = InferSelectModel; export type Account = InferSelectModel; diff --git a/server/db/sqlite/schema/privateSchema.ts b/server/db/sqlite/schema/privateSchema.ts index 8baeb5220..ac2662e27 100644 --- a/server/db/sqlite/schema/privateSchema.ts +++ b/server/db/sqlite/schema/privateSchema.ts @@ -2,6 +2,7 @@ import { InferSelectModel } from "drizzle-orm"; import { index, integer, + primaryKey, real, sqliteTable, text @@ -318,7 +319,6 @@ export const approvals = sqliteTable("approvals", { .notNull() }); - export const bannedEmails = sqliteTable("bannedEmails", { email: text("email").primaryKey() }); @@ -327,6 +327,33 @@ export const bannedIps = sqliteTable("bannedIps", { ip: text("ip").primaryKey() }); +export const siteProvisioningKeys = sqliteTable("siteProvisioningKeys", { + siteProvisioningKeyId: text("siteProvisioningKeyId").primaryKey(), + name: text("name").notNull(), + siteProvisioningKeyHash: text("siteProvisioningKeyHash").notNull(), + lastChars: text("lastChars").notNull(), + createdAt: text("dateCreated").notNull() +}); + +export const siteProvisioningKeyOrg = sqliteTable( + "siteProvisioningKeyOrg", + { + siteProvisioningKeyId: text("siteProvisioningKeyId") + .notNull() + .references(() => siteProvisioningKeys.siteProvisioningKeyId, { + onDelete: "cascade" + }), + orgId: text("orgId") + .notNull() + .references(() => orgs.orgId, { onDelete: "cascade" }) + }, + (table) => [ + primaryKey({ + columns: [table.siteProvisioningKeyId, table.orgId] + }) + ] +); + export type Approval = InferSelectModel; export type Limit = InferSelectModel; export type Account = InferSelectModel; diff --git a/server/middlewares/index.ts b/server/middlewares/index.ts index 6437c90e2..0f485e637 100644 --- a/server/middlewares/index.ts +++ b/server/middlewares/index.ts @@ -24,6 +24,7 @@ export * from "./verifyClientAccess"; export * from "./integration"; export * from "./verifyUserHasAction"; export * from "./verifyApiKeyAccess"; +export * from "./verifySiteProvisioningKeyAccess"; export * from "./verifyDomainAccess"; export * from "./verifyUserIsOrgOwner"; export * from "./verifySiteResourceAccess"; diff --git a/server/middlewares/verifySiteProvisioningKeyAccess.ts b/server/middlewares/verifySiteProvisioningKeyAccess.ts new file mode 100644 index 000000000..e0d446de6 --- /dev/null +++ b/server/middlewares/verifySiteProvisioningKeyAccess.ts @@ -0,0 +1,131 @@ +import { Request, Response, NextFunction } from "express"; +import { db, userOrgs, siteProvisioningKeys, siteProvisioningKeyOrg } from "@server/db"; +import { and, eq } from "drizzle-orm"; +import createHttpError from "http-errors"; +import HttpCode from "@server/types/HttpCode"; +import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy"; + +export async function verifySiteProvisioningKeyAccess( + req: Request, + res: Response, + next: NextFunction +) { + try { + const userId = req.user!.userId; + const siteProvisioningKeyId = req.params.siteProvisioningKeyId; + const orgId = req.params.orgId; + + if (!userId) { + return next( + createHttpError(HttpCode.UNAUTHORIZED, "User not authenticated") + ); + } + + if (!orgId) { + return next( + createHttpError(HttpCode.BAD_REQUEST, "Invalid organization ID") + ); + } + + if (!siteProvisioningKeyId) { + return next( + createHttpError(HttpCode.BAD_REQUEST, "Invalid key ID") + ); + } + + const [row] = await db + .select() + .from(siteProvisioningKeys) + .innerJoin( + siteProvisioningKeyOrg, + and( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyOrg.siteProvisioningKeyId + ), + eq(siteProvisioningKeyOrg.orgId, orgId) + ) + ) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ) + .limit(1); + + if (!row?.siteProvisioningKeys) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Site provisioning key with ID ${siteProvisioningKeyId} not found` + ) + ); + } + + if (!row.siteProvisioningKeyOrg.orgId) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + `Site provisioning key with ID ${siteProvisioningKeyId} does not have an organization ID` + ) + ); + } + + if (!req.userOrg) { + const userOrgRole = await db + .select() + .from(userOrgs) + .where( + and( + eq(userOrgs.userId, userId), + eq( + userOrgs.orgId, + row.siteProvisioningKeyOrg.orgId + ) + ) + ) + .limit(1); + req.userOrg = userOrgRole[0]; + } + + if (!req.userOrg) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "User does not have access to this organization" + ) + ); + } + + if (req.orgPolicyAllowed === undefined && req.userOrg.orgId) { + const policyCheck = await checkOrgAccessPolicy({ + orgId: req.userOrg.orgId, + userId, + session: req.session + }); + req.orgPolicyAllowed = policyCheck.allowed; + if (!policyCheck.allowed || policyCheck.error) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Failed organization access policy check: " + + (policyCheck.error || "Unknown error") + ) + ); + } + } + + const userOrgRoleId = req.userOrg.roleId; + req.userOrgRoleId = userOrgRoleId; + + return next(); + } catch (error) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Error verifying site provisioning key access" + ) + ); + } +} diff --git a/server/routers/external.ts b/server/routers/external.ts index 45ab58bba..90f208863 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -15,6 +15,7 @@ import * as accessToken from "./accessToken"; import * as idp from "./idp"; import * as blueprints from "./blueprints"; import * as apiKeys from "./apiKeys"; +import * as siteProvisioning from "./siteProvisioning"; import * as logs from "./auditLogs"; import * as newt from "./newt"; import * as olm from "./olm"; @@ -42,7 +43,8 @@ import { verifyUserIsOrgOwner, verifySiteResourceAccess, verifyOlmAccess, - verifyLimits + verifyLimits, + verifySiteProvisioningKeyAccess } from "@server/middlewares"; import { ActionsEnum } from "@server/auth/actions"; import rateLimit, { ipKeyGenerator } from "express-rate-limit"; @@ -986,6 +988,31 @@ authenticated.get( apiKeys.listRootApiKeys ); +authenticated.put( + `/org/:orgId/site-provisioning-key`, + verifyOrgAccess, + verifyLimits, + verifyUserHasAction(ActionsEnum.createSiteProvisioningKey), + logActionAudit(ActionsEnum.createSiteProvisioningKey), + siteProvisioning.createSiteProvisioningKey +); + +authenticated.get( + `/org/:orgId/site-provisioning-keys`, + verifyOrgAccess, + verifyUserHasAction(ActionsEnum.listSiteProvisioningKeys), + siteProvisioning.listSiteProvisioningKeys +); + +authenticated.delete( + `/org/:orgId/site-provisioning-key/:siteProvisioningKeyId`, + verifyOrgAccess, + verifySiteProvisioningKeyAccess, + verifyUserHasAction(ActionsEnum.deleteSiteProvisioningKey), + logActionAudit(ActionsEnum.deleteSiteProvisioningKey), + siteProvisioning.deleteSiteProvisioningKey +); + authenticated.get( `/api-key/:apiKeyId/actions`, verifyUserIsServerAdmin, diff --git a/server/routers/siteProvisioning/createSiteProvisioningKey.ts b/server/routers/siteProvisioning/createSiteProvisioningKey.ts new file mode 100644 index 000000000..9bb298966 --- /dev/null +++ b/server/routers/siteProvisioning/createSiteProvisioningKey.ts @@ -0,0 +1,108 @@ +import { NextFunction, Request, Response } from "express"; +import { db, siteProvisioningKeyOrg, siteProvisioningKeys } from "@server/db"; +import HttpCode from "@server/types/HttpCode"; +import { z } from "zod"; +import { fromError } from "zod-validation-error"; +import createHttpError from "http-errors"; +import response from "@server/lib/response"; +import moment from "moment"; +import { + generateId, + generateIdFromEntropySize +} from "@server/auth/sessions/app"; +import logger from "@server/logger"; +import { hashPassword } from "@server/auth/password"; + +const paramsSchema = z.object({ + orgId: z.string().nonempty() +}); + +const bodySchema = z.strictObject({ + name: z.string().min(1).max(255) +}); + +export type CreateSiteProvisioningKeyBody = z.infer; + +export type CreateSiteProvisioningKeyResponse = { + siteProvisioningKeyId: string; + orgId: string; + name: string; + siteProvisioningKey: string; + lastChars: string; + createdAt: string; +}; + +export async function createSiteProvisioningKey( + req: Request, + res: Response, + next: NextFunction +): Promise { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const parsedBody = bodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { orgId } = parsedParams.data; + const { name } = parsedBody.data; + + const siteProvisioningKeyId = `spk-${generateId(15)}`; + const siteProvisioningKey = generateIdFromEntropySize(25); + const siteProvisioningKeyHash = await hashPassword(siteProvisioningKey); + const lastChars = siteProvisioningKey.slice(-4); + const createdAt = moment().toISOString(); + + await db.transaction(async (trx) => { + await trx.insert(siteProvisioningKeys).values({ + siteProvisioningKeyId, + name, + siteProvisioningKeyHash, + createdAt, + lastChars + }); + + await trx.insert(siteProvisioningKeyOrg).values({ + siteProvisioningKeyId, + orgId + }); + }); + + try { + return response(res, { + data: { + siteProvisioningKeyId, + orgId, + name, + siteProvisioningKey, + lastChars, + createdAt + }, + success: true, + error: false, + message: "Site provisioning key created", + status: HttpCode.CREATED + }); + } catch (e) { + logger.error(e); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to create site provisioning key" + ) + ); + } +} diff --git a/server/routers/siteProvisioning/deleteSiteProvisioningKey.ts b/server/routers/siteProvisioning/deleteSiteProvisioningKey.ts new file mode 100644 index 000000000..d1da01d97 --- /dev/null +++ b/server/routers/siteProvisioning/deleteSiteProvisioningKey.ts @@ -0,0 +1,116 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { + db, + siteProvisioningKeyOrg, + siteProvisioningKeys +} from "@server/db"; +import { and, eq } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; + +const paramsSchema = z.object({ + siteProvisioningKeyId: z.string().nonempty(), + orgId: z.string().nonempty() +}); + +export async function deleteSiteProvisioningKey( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { siteProvisioningKeyId, orgId } = parsedParams.data; + + const [row] = await db + .select() + .from(siteProvisioningKeys) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ) + .innerJoin( + siteProvisioningKeyOrg, + and( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyOrg.siteProvisioningKeyId + ), + eq(siteProvisioningKeyOrg.orgId, orgId) + ) + ) + .limit(1); + + if (!row) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Site provisioning key with ID ${siteProvisioningKeyId} not found` + ) + ); + } + + await db.transaction(async (trx) => { + await trx + .delete(siteProvisioningKeyOrg) + .where( + and( + eq( + siteProvisioningKeyOrg.siteProvisioningKeyId, + siteProvisioningKeyId + ), + eq(siteProvisioningKeyOrg.orgId, orgId) + ) + ); + + const siteProvisioningKeyOrgs = await trx + .select() + .from(siteProvisioningKeyOrg) + .where( + eq( + siteProvisioningKeyOrg.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ); + + if (siteProvisioningKeyOrgs.length === 0) { + await trx + .delete(siteProvisioningKeys) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ); + } + }); + + return response(res, { + data: null, + success: true, + error: false, + message: "Site provisioning key deleted successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/server/routers/siteProvisioning/index.ts b/server/routers/siteProvisioning/index.ts new file mode 100644 index 000000000..b3f69f100 --- /dev/null +++ b/server/routers/siteProvisioning/index.ts @@ -0,0 +1,3 @@ +export * from "./createSiteProvisioningKey"; +export * from "./listSiteProvisioningKeys"; +export * from "./deleteSiteProvisioningKey"; diff --git a/server/routers/siteProvisioning/listSiteProvisioningKeys.ts b/server/routers/siteProvisioning/listSiteProvisioningKeys.ts new file mode 100644 index 000000000..65360625c --- /dev/null +++ b/server/routers/siteProvisioning/listSiteProvisioningKeys.ts @@ -0,0 +1,115 @@ +import { + db, + siteProvisioningKeyOrg, + siteProvisioningKeys +} from "@server/db"; +import logger from "@server/logger"; +import HttpCode from "@server/types/HttpCode"; +import response from "@server/lib/response"; +import { NextFunction, Request, Response } from "express"; +import createHttpError from "http-errors"; +import { z } from "zod"; +import { fromError } from "zod-validation-error"; +import { eq } from "drizzle-orm"; + +const paramsSchema = z.object({ + orgId: z.string().nonempty() +}); + +const querySchema = z.object({ + limit: z + .string() + .optional() + .default("1000") + .transform(Number) + .pipe(z.int().positive()), + offset: z + .string() + .optional() + .default("0") + .transform(Number) + .pipe(z.int().nonnegative()) +}); + +function querySiteProvisioningKeys(orgId: string) { + return db + .select({ + siteProvisioningKeyId: + siteProvisioningKeys.siteProvisioningKeyId, + orgId: siteProvisioningKeyOrg.orgId, + lastChars: siteProvisioningKeys.lastChars, + createdAt: siteProvisioningKeys.createdAt, + name: siteProvisioningKeys.name + }) + .from(siteProvisioningKeyOrg) + .innerJoin( + siteProvisioningKeys, + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyOrg.siteProvisioningKeyId + ) + ) + .where(eq(siteProvisioningKeyOrg.orgId, orgId)); +} + +export type ListSiteProvisioningKeysResponse = { + siteProvisioningKeys: Awaited< + ReturnType + >; + pagination: { total: number; limit: number; offset: number }; +}; + +export async function listSiteProvisioningKeys( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error) + ) + ); + } + + const parsedQuery = querySchema.safeParse(req.query); + if (!parsedQuery.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedQuery.error) + ) + ); + } + + const { orgId } = parsedParams.data; + const { limit, offset } = parsedQuery.data; + + const siteProvisioningKeysList = await querySiteProvisioningKeys(orgId) + .limit(limit) + .offset(offset); + + return response(res, { + data: { + siteProvisioningKeys: siteProvisioningKeysList, + pagination: { + total: siteProvisioningKeysList.length, + limit, + offset + } + }, + success: true, + error: false, + message: "Site provisioning keys retrieved successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +}