This commit is contained in:
Owen
2025-10-04 18:36:44 -07:00
parent 3123f858bb
commit c2c907852d
320 changed files with 35785 additions and 2984 deletions

View File

@@ -112,7 +112,7 @@ export async function createOidcIdp(
});
});
const redirectUrl = generateOidcRedirectUrl(idpId as number);
const redirectUrl = await generateOidcRedirectUrl(idpId as number);
return response<CreateIdpResponse>(res, {
data: {

View File

@@ -12,6 +12,7 @@ import { OpenAPITags, registry } from "@server/openApi";
const paramsSchema = z
.object({
orgId: z.string().optional(), // Optional; used with org idp in saas
idpId: z.coerce.number()
})
.strict();

View File

@@ -10,10 +10,12 @@ import { idp, idpOidcConfig, idpOrg } from "@server/db";
import { and, eq } from "drizzle-orm";
import * as arctic from "arctic";
import { generateOidcRedirectUrl } from "@server/lib/idp/generateRedirectUrl";
import cookie from "cookie";
import jsonwebtoken from "jsonwebtoken";
import config from "@server/lib/config";
import { decrypt } from "@server/lib/crypto";
import { build } from "@server/build";
import { getOrgTierData } from "@server/routers/private/billing";
import { TierId } from "@server/lib/private/billing/tiers";
const paramsSchema = z
.object({
@@ -27,6 +29,10 @@ const bodySchema = z
})
.strict();
const querySchema = z.object({
orgId: z.string().optional() // check what actuall calls it
});
const ensureTrailingSlash = (url: string): string => {
return url;
};
@@ -65,6 +71,18 @@ export async function generateOidcUrl(
const { redirectUrl: postAuthRedirectUrl } = parsedBody.data;
const parsedQuery = querySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { orgId } = parsedQuery.data;
const [existingIdp] = await db
.select()
.from(idp)
@@ -80,6 +98,36 @@ export async function generateOidcUrl(
);
}
if (orgId) {
const [idpOrgLink] = await db
.select()
.from(idpOrg)
.where(and(eq(idpOrg.idpId, idpId), eq(idpOrg.orgId, orgId)))
.limit(1);
if (!idpOrgLink) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"IdP not found for the organization"
)
);
}
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"This organization's current plan does not support this feature."
)
);
}
}
}
const parsedScopes = existingIdp.idpOidcConfig.scopes
.split(" ")
.map((scope) => {
@@ -100,7 +148,12 @@ export async function generateOidcUrl(
key
);
const redirectUrl = generateOidcRedirectUrl(idpId);
const redirectUrl = await generateOidcRedirectUrl(idpId, orgId);
logger.debug("OIDC client info", {
decryptedClientId,
decryptedClientSecret,
redirectUrl
});
const client = new arctic.OAuth2Client(
decryptedClientId,
decryptedClientSecret,
@@ -116,7 +169,6 @@ export async function generateOidcUrl(
codeVerifier,
parsedScopes
);
logger.debug("Generated OIDC URL", { url });
const stateJwt = jsonwebtoken.sign(
{

View File

@@ -8,4 +8,4 @@ export * from "./getIdp";
export * from "./createIdpOrgPolicy";
export * from "./deleteIdpOrgPolicy";
export * from "./listIdpOrgPolicies";
export * from "./updateIdpOrgPolicy";
export * from "./updateIdpOrgPolicy";

View File

@@ -30,6 +30,8 @@ import {
} from "@server/auth/sessions/app";
import { decrypt } from "@server/lib/crypto";
import { UserType } from "@server/types/UserTypes";
import { FeatureId } from "@server/lib/private/billing";
import { usageService } from "@server/lib/private/billing/usageService";
const ensureTrailingSlash = (url: string): string => {
return url;
@@ -47,6 +49,10 @@ const bodySchema = z.object({
storedState: z.string().nonempty()
});
const querySchema = z.object({
loginPageId: z.coerce.number().optional()
});
export type ValidateOidcUrlCallbackResponse = {
redirectUrl: string;
};
@@ -79,6 +85,18 @@ export async function validateOidcCallback(
);
}
const parsedQuery = querySchema.safeParse(req.query);
if (!parsedQuery.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedQuery.error).toString()
)
);
}
const { loginPageId } = parsedQuery.data;
const { storedState, code, state: expectedState } = parsedBody.data;
const [existingIdp] = await db
@@ -107,7 +125,11 @@ export async function validateOidcCallback(
key
);
const redirectUrl = generateOidcRedirectUrl(existingIdp.idp.idpId);
const redirectUrl = await generateOidcRedirectUrl(
existingIdp.idp.idpId,
undefined,
loginPageId
);
const client = new arctic.OAuth2Client(
decryptedClientId,
decryptedClientSecret,
@@ -380,12 +402,14 @@ export async function validateOidcCallback(
}
// Update roles for existing auto-provisioned orgs where the role has changed
const orgsToUpdate = autoProvisionedOrgs.filter((currentOrg) => {
const newOrg = userOrgInfo.find(
(newOrg) => newOrg.orgId === currentOrg.orgId
);
return newOrg && newOrg.roleId !== currentOrg.roleId;
});
const orgsToUpdate = autoProvisionedOrgs.filter(
(currentOrg) => {
const newOrg = userOrgInfo.find(
(newOrg) => newOrg.orgId === currentOrg.orgId
);
return newOrg && newOrg.roleId !== currentOrg.roleId;
}
);
if (orgsToUpdate.length > 0) {
for (const org of orgsToUpdate) {
@@ -441,6 +465,14 @@ export async function validateOidcCallback(
}
});
for (const orgCount of orgUserCounts) {
await usageService.updateDaily(
orgCount.orgId,
FeatureId.USERS,
orgCount.userCount
);
}
const token = generateSessionToken();
const sess = await createSession(token, existingUserId!);
const isSecure = req.protocol === "https";