diff --git a/server/db/queries/verifySessionQueries.ts b/server/db/queries/verifySessionQueries.ts index 66f968b02..989e111a7 100644 --- a/server/db/queries/verifySessionQueries.ts +++ b/server/db/queries/verifySessionQueries.ts @@ -1,4 +1,12 @@ -import { db, loginPage, LoginPage, loginPageOrg, Org, orgs, roles } from "@server/db"; +import { + db, + loginPage, + LoginPage, + loginPageOrg, + Org, + orgs, + roles +} from "@server/db"; import { Resource, ResourcePassword, @@ -12,14 +20,12 @@ import { resources, roleResources, sessions, - userOrgRoles, - userOrgs, userResources, users, ResourceHeaderAuthExtendedCompatibility, resourceHeaderAuthExtendedCompatibility } from "@server/db"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; export type ResourceWithAuth = { resource: Resource | null; @@ -121,7 +127,7 @@ export async function getRoleName(roleId: number): Promise { */ export async function getRoleResourceAccess( resourceId: number, - roleId: number + roleIds: number[] ) { const roleResourceAccess = await db .select() @@ -129,12 +135,11 @@ export async function getRoleResourceAccess( .where( and( eq(roleResources.resourceId, resourceId), - eq(roleResources.roleId, roleId) + inArray(roleResources.roleId, roleIds) ) - ) - .limit(1); + ); - return roleResourceAccess.length > 0 ? roleResourceAccess[0] : null; + return roleResourceAccess.length > 0 ? roleResourceAccess : null; } /** diff --git a/server/lib/userOrgRoles.ts b/server/lib/userOrgRoles.ts index 5a4d75659..c3db64af3 100644 --- a/server/lib/userOrgRoles.ts +++ b/server/lib/userOrgRoles.ts @@ -1,4 +1,4 @@ -import { db, userOrgRoles } from "@server/db"; +import { db, roles, userOrgRoles } from "@server/db"; import { and, eq } from "drizzle-orm"; /** @@ -20,3 +20,17 @@ export async function getUserOrgRoleIds( ); return rows.map((r) => r.roleId); } + +export async function getUserOrgRoles( + userId: string, + orgId: string +): Promise<{ roleId: number; roleName: string }[]> { + const rows = await db + .select({ roleId: userOrgRoles.roleId, roleName: roles.name }) + .from(userOrgRoles) + .innerJoin(roles, eq(userOrgRoles.roleId, roles.roleId)) + .where( + and(eq(userOrgRoles.userId, userId), eq(userOrgRoles.orgId, orgId)) + ); + return rows; +} diff --git a/server/private/routers/hybrid.ts b/server/private/routers/hybrid.ts index 5ca720594..71fdc7e72 100644 --- a/server/private/routers/hybrid.ts +++ b/server/private/routers/hybrid.ts @@ -124,6 +124,23 @@ const getRoleResourceAccessParamsSchema = z.strictObject({ .pipe(z.int().positive("Resource ID must be a positive integer")) }); +const getResourceAccessParamsSchema = z.strictObject({ + resourceId: z + .string() + .transform(Number) + .pipe(z.int().positive("Resource ID must be a positive integer")) +}); + +const getResourceAccessQuerySchema = z.strictObject({ + roleIds: z + .union([z.array(z.string()), z.string()]) + .transform((val) => + (Array.isArray(val) ? val : [val]) + .map(Number) + .filter((n) => !isNaN(n)) + ) +}); + const getUserResourceAccessParamsSchema = z.strictObject({ userId: z.string().min(1, "User ID is required"), resourceId: z @@ -769,7 +786,7 @@ hybridRouter.get( // Get user organization role hybridRouter.get( - "/user/:userId/org/:orgId/role", + "/user/:userId/org/:orgId/roles", async (req: Request, res: Response, next: NextFunction) => { try { const parsedParams = getUserOrgRoleParamsSchema.safeParse( @@ -805,6 +822,80 @@ hybridRouter.get( ); } + const userOrgRoleRows = await db + .select({ roleId: userOrgRoles.roleId, roleName: roles.name }) + .from(userOrgRoles) + .innerJoin(roles, eq(roles.roleId, userOrgRoles.roleId)) + .where( + and( + eq(userOrgRoles.userId, userId), + eq(userOrgRoles.orgId, orgId) + ) + ); + + return response<{ roleId: number, roleName: string }[]>(res, { + data: userOrgRoleRows, + success: true, + error: false, + message: + userOrgRoleRows.length > 0 + ? "User org roles retrieved successfully" + : "User has no roles in this organization", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to get user org role" + ) + ); + } + } +); + +// DEPRICATED Get user organization role +// used for backward compatibility with old remote nodes +hybridRouter.get( + "/user/:userId/org/:orgId/role", // <- note the missing s + async (req: Request, res: Response, next: NextFunction) => { + try { + const parsedParams = getUserOrgRoleParamsSchema.safeParse( + req.params + ); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { userId, orgId } = parsedParams.data; + const remoteExitNode = req.remoteExitNode; + + if (!remoteExitNode || !remoteExitNode.exitNodeId) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Remote exit node not found" + ) + ); + } + + if (await checkExitNodeOrg(remoteExitNode.exitNodeId, orgId)) { + return next( + createHttpError( + HttpCode.UNAUTHORIZED, + "User is not authorized to access this organization" + ) + ); + } + + // get the roles on the user + const userOrgRoleRows = await db .select({ roleId: userOrgRoles.roleId }) .from(userOrgRoles) @@ -817,8 +908,35 @@ hybridRouter.get( const roleIds = userOrgRoleRows.map((r) => r.roleId); - return response(res, { - data: roleIds, + let roleId: number | null = null; + + if (userOrgRoleRows.length === 0) { + // User has no roles in this organization + roleId = null; + } else if (userOrgRoleRows.length === 1) { + // User has exactly one role, return it + roleId = userOrgRoleRows[0].roleId; + } else { + // User has multiple roles + // Check if any of these roles are also assigned to a resource + // If we find a match, prefer that role; otherwise return the first role + // Get all resources that have any of these roles assigned + const roleResourceMatches = await db + .select({ roleId: roleResources.roleId }) + .from(roleResources) + .where(inArray(roleResources.roleId, roleIds)) + .limit(1); + if (roleResourceMatches.length > 0) { + // Return the first role that's also on a resource + roleId = roleResourceMatches[0].roleId; + } else { + // No resource match found, return the first role + roleId = userOrgRoleRows[0].roleId; + } + } + + return response<{ roleId: number | null }>(res, { + data: { roleId }, success: true, error: false, message: @@ -939,7 +1057,9 @@ hybridRouter.get( data: role?.name ?? null, success: true, error: false, - message: role ? "Role name retrieved successfully" : "Role not found", + message: role + ? "Role name retrieved successfully" + : "Role not found", status: HttpCode.OK }); } catch (error) { @@ -1039,6 +1159,101 @@ hybridRouter.get( } ); +// Check if role has access to resource +hybridRouter.get( + "/resource/:resourceId/access", + async (req: Request, res: Response, next: NextFunction) => { + try { + const parsedParams = getResourceAccessParamsSchema.safeParse( + req.params + ); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { resourceId } = parsedParams.data; + const parsedQuery = getResourceAccessQuerySchema.safeParse( + req.query + ); + const roleIds = parsedQuery.success ? parsedQuery.data.roleIds : []; + + const remoteExitNode = req.remoteExitNode; + + if (!remoteExitNode?.exitNodeId) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Remote exit node not found" + ) + ); + } + + const [resource] = await db + .select() + .from(resources) + .where(eq(resources.resourceId, resourceId)) + .limit(1); + + if ( + await checkExitNodeOrg( + remoteExitNode.exitNodeId, + resource.orgId + ) + ) { + // If the exit node is not allowed for the org, return an error + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Exit node not allowed for this organization" + ) + ); + } + + const roleResourceAccess = await db + .select({ + resourceId: roleResources.resourceId, + roleId: roleResources.roleId + }) + .from(roleResources) + .where( + and( + eq(roleResources.resourceId, resourceId), + inArray(roleResources.roleId, roleIds) + ) + ); + + const result = + roleResourceAccess.length > 0 ? roleResourceAccess : null; + + return response<{ resourceId: number; roleId: number }[] | null>( + res, + { + data: result, + success: true, + error: false, + message: result + ? "Role resource access retrieved successfully" + : "Role resource access not found", + status: HttpCode.OK + } + ); + } catch (error) { + logger.error(error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to get role resource access" + ) + ); + } + } +); + // Check if user has direct access to resource hybridRouter.get( "/user/:userId/resource/:resourceId/access", @@ -1937,7 +2152,8 @@ hybridRouter.post( // userAgent: data.userAgent, // TODO: add this // headers: data.body.headers, // query: data.body.query, - originalRequestURL: sanitizeString(logEntry.originalRequestURL) ?? "", + originalRequestURL: + sanitizeString(logEntry.originalRequestURL) ?? "", scheme: sanitizeString(logEntry.scheme) ?? "", host: sanitizeString(logEntry.host) ?? "", path: sanitizeString(logEntry.path) ?? "", diff --git a/server/routers/badger/verifySession.ts b/server/routers/badger/verifySession.ts index 182b35dcb..7ea281c9b 100644 --- a/server/routers/badger/verifySession.ts +++ b/server/routers/badger/verifySession.ts @@ -9,7 +9,7 @@ import { getOrgLoginPage, getUserSessionWithUser } from "@server/db/queries/verifySessionQueries"; -import { getUserOrgRoleIds } from "@server/lib/userOrgRoles"; +import { getUserOrgRoles } from "@server/lib/userOrgRoles"; import { LoginPage, Org, @@ -798,7 +798,8 @@ async function notAllowed( ) { let loginPage: LoginPage | null = null; if (orgId) { - const subscribed = await isSubscribed( // this is fine because the org login page is only a saas feature + const subscribed = await isSubscribed( + // this is fine because the org login page is only a saas feature orgId, tierMatrix.loginPageDomain ); @@ -855,7 +856,10 @@ async function headerAuthChallenged( ) { let loginPage: LoginPage | null = null; if (orgId) { - const subscribed = await isSubscribed(orgId, tierMatrix.loginPageDomain); // this is fine because the org login page is only a saas feature + const subscribed = await isSubscribed( + orgId, + tierMatrix.loginPageDomain + ); // this is fine because the org login page is only a saas feature if (subscribed) { loginPage = await getOrgLoginPage(orgId); } @@ -917,9 +921,9 @@ async function isUserAllowedToAccessResource( return null; } - const userOrgRoleIds = await getUserOrgRoleIds(user.userId, resource.orgId); + const userOrgRoles = await getUserOrgRoles(user.userId, resource.orgId); - if (!userOrgRoleIds.length) { + if (!userOrgRoles.length) { return null; } @@ -935,23 +939,16 @@ async function isUserAllowedToAccessResource( return null; } - const roleNames: string[] = []; - for (const roleId of userOrgRoleIds) { - const roleResourceAccess = await getRoleResourceAccess( - resource.resourceId, - roleId - ); - if (roleResourceAccess) { - const roleName = await getRoleName(roleId); - if (roleName) roleNames.push(roleName); - } - } - if (roleNames.length > 0) { + const roleResourceAccess = await getRoleResourceAccess( + resource.resourceId, + userOrgRoles.map((r) => r.roleId) + ); + if (roleResourceAccess && roleResourceAccess.length > 0) { return { username: user.username, email: user.email, name: user.name, - role: roleNames.join(", ") + role: userOrgRoles.map((r) => r.roleName).join(", ") }; } @@ -961,15 +958,11 @@ async function isUserAllowedToAccessResource( ); if (userResourceAccess) { - const names = await Promise.all( - userOrgRoleIds.map((id) => getRoleName(id)) - ); - const role = names.filter(Boolean).join(", ") || ""; return { username: user.username, email: user.email, name: user.name, - role + role: userOrgRoles.map((r) => r.roleName).join(", ") }; }