diff --git a/server/private/routers/policy/createResourcePolicy.ts b/server/private/routers/policy/createResourcePolicy.ts index 6cf710810..dc6780616 100644 --- a/server/private/routers/policy/createResourcePolicy.ts +++ b/server/private/routers/policy/createResourcePolicy.ts @@ -27,7 +27,7 @@ const createResourcePolicyParamsSchema = z.strictObject({ const createResourcePolicyBodySchema = z.strictObject({ name: z.string().min(1).max(255), sso: z.boolean(), - skipToIdpId: z.string().optional(), + skipToIdpId: z.int().positive().optional(), roleIds: z .array(z.string().transform(Number).pipe(z.int().positive())) .optional() @@ -150,7 +150,9 @@ export async function createResourcePolicy( .select() .from(users) .innerJoin(userOrgs, eq(userOrgs.userId, users.userId)) - .where(and(inArray(users.userId, userIds))); + .where( + and(eq(userOrgs.orgId, orgId), inArray(users.userId, userIds)) + ); const niceId = await getUniqueResourcePolicyName(orgId); diff --git a/server/routers/policy/setResourcePolicyAccessControl.ts b/server/routers/policy/setResourcePolicyAccessControl.ts index 430c9e59f..98f43f5fa 100644 --- a/server/routers/policy/setResourcePolicyAccessControl.ts +++ b/server/routers/policy/setResourcePolicyAccessControl.ts @@ -1,20 +1,29 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; -import { db } from "@server/db"; -import { userResources } from "@server/db"; +import { + db, + idp, + idpOrg, + resourcePolicies, + rolePolicies, + roles, + userOrgs, + users +} from "@server/db"; +import { userPolicies } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; import logger from "@server/logger"; import { fromError } from "zod-validation-error"; -import { eq } from "drizzle-orm"; +import { and, eq, inArray, ne, type InferInsertModel } from "drizzle-orm"; import { OpenAPITags, registry } from "@server/openApi"; const setResourcePolicyAcccessControlBodySchema = z.strictObject({ sso: z.boolean(), userIds: z.array(z.string()), roleIds: z.array(z.int().positive()), - skipToIdpId: z.string().optional() + skipToIdpId: z.int().positive().optional() }); const setResourcePolicyAccessControlParamsSchema = z.strictObject({ @@ -58,7 +67,7 @@ export async function setResourceUsers( ); } - const { userIds } = parsedBody.data; + const { userIds, roleIds, sso, skipToIdpId: idpId } = parsedBody.data; const parsedParams = setResourcePolicyAccessControlParamsSchema.safeParse(req.params); @@ -73,26 +82,146 @@ export async function setResourceUsers( const { resourcePolicyId } = parsedParams.data; - await db.transaction(async (trx) => { - await trx - .delete(userResources) - .where(eq(userResources.resourceId, resourceId)); + const [policy] = await db + .select() + .from(resourcePolicies) + .where(eq(resourcePolicies.resourcePolicyId, resourcePolicyId)) + .limit(1); - const newUserResources = await Promise.all( - userIds.map((userId) => - trx - .insert(userResources) - .values({ userId, resourceId }) - .returning() + if (!policy) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Resource policy not found" ) ); + } + + // Check if Identity provider in `skipToIdpId` exists + if (idpId) { + const [provider] = await db + .select() + .from(idp) + .innerJoin(idpOrg, eq(idpOrg.idpId, idp.idpId)) + .where( + and(eq(idp.idpId, idpId), eq(idpOrg.orgId, policy.orgId)) + ) + .limit(1); + + if (!provider) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Identity provider not found in this organization" + ) + ); + } + } + + // Check if any of the roleIds are admin roles + const rolesToCheck = await db + .select() + .from(roles) + .where( + and( + inArray(roles.roleId, roleIds), + eq(roles.orgId, policy.orgId) + ) + ); + + const hasAdminRole = rolesToCheck.some((role) => role.isAdmin); + + if (hasAdminRole) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Admin role cannot be assigned to resources" + ) + ); + } + + // Get all admin role IDs for this org to exclude from deletion + const adminRoles = await db + .select() + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, policy.orgId))); + const adminRoleIds = adminRoles.map((role) => role.roleId); + + const existingUsers = await db + .select() + .from(users) + .innerJoin(userOrgs, eq(userOrgs.userId, users.userId)) + .where( + and( + eq(userOrgs.orgId, policy.orgId), + inArray(users.userId, userIds) + ) + ); + + const existingRoles = await db + .select() + .from(roles) + .where( + and( + eq(roles.orgId, policy.orgId), + inArray(roles.roleId, roleIds) + ) + ); + + await db.transaction(async (trx) => { + // Update SSO status + await trx + .update(resourcePolicies) + .set({ + sso, + idpId + }) + .where(eq(resourcePolicies.resourcePolicyId, resourcePolicyId)); + + // Update roles + if (adminRoleIds.length > 0) { + await trx.delete(rolePolicies).where( + and( + eq(rolePolicies.resourcePolicyId, resourcePolicyId), + ne(rolePolicies.roleId, adminRoleIds[0]) // delete all but the admin role + ) + ); + } else { + await trx + .delete(rolePolicies) + .where(eq(rolePolicies.resourcePolicyId, resourcePolicyId)); + } + + const rolesToAdd = existingRoles.map(({ roleId }) => ({ + roleId, + resourcePolicyId + })); + + if (rolesToAdd.length > 0) { + await trx.insert(rolePolicies).values(rolesToAdd); + } + + // Update users + await trx + .delete(userPolicies) + .where(eq(userPolicies.resourcePolicyId, resourcePolicyId)); + + const usersToAdd = existingUsers.map(({ user }) => ({ + userId: user.userId, + resourcePolicyId: resourcePolicyId + })); + + if (usersToAdd.length > 0) { + await trx.insert(userPolicies).values(usersToAdd); + } }); + return response(res, { data: {}, success: true, error: false, - message: "Users set for resource successfully", - status: HttpCode.CREATED + message: "Resource policy succesfully updated", + status: HttpCode.OK }); } catch (error) { logger.error(error);