diff --git a/server/lib/rebuildClientAssociations.ts b/server/lib/rebuildClientAssociations.ts index a7a134f6d..a9a2759c4 100644 --- a/server/lib/rebuildClientAssociations.ts +++ b/server/lib/rebuildClientAssociations.ts @@ -18,7 +18,7 @@ import { userOrgRoles, userSiteResources } from "@server/db"; -import { and, eq, inArray, ne } from "drizzle-orm"; +import { and, count, eq, inArray, ne } from "drizzle-orm"; import { deletePeer as newtDeletePeer } from "@server/routers/newt/peers"; import { @@ -39,6 +39,11 @@ import { removePeerData, removeTargets as removeSubnetProxyTargets } from "@server/routers/client/targets"; +import { lockManager } from "#dynamic/lib/lock"; + +// TTL for rebuild-association locks. These functions can fan out into many +// peer/proxy updates, so give them a generous window. +const REBUILD_ASSOCIATIONS_LOCK_TTL_MS = 120000; export async function getClientSiteResourceAccess( siteResource: SiteResource, @@ -161,6 +166,23 @@ export async function rebuildClientAssociationsFromSiteResource( pubKey: string | null; subnet: string | null; }[]; +}> { + return await lockManager.withLock( + `rebuild-client-associations:site-resource:${siteResource.siteResourceId}`, + () => rebuildClientAssociationsFromSiteResourceImpl(siteResource, trx), + REBUILD_ASSOCIATIONS_LOCK_TTL_MS + ); +} + +async function rebuildClientAssociationsFromSiteResourceImpl( + siteResource: SiteResource, + trx: Transaction | typeof db = db +): Promise<{ + mergedAllClients: { + clientId: number; + pubKey: string | null; + subnet: string | null; + }[]; }> { logger.debug( `rebuildClientAssociations: [rebuildClientAssociationsFromSiteResource] START siteResourceId=${siteResource.siteResourceId} networkId=${siteResource.networkId} orgId=${siteResource.orgId}` @@ -539,6 +561,29 @@ async function handleMessagesForSiteClients( } } + // get the number of sites on each of these clients so we can log it and make decisions about whether to send messages based on it + const clientSiteCounts: Record = {}; + if (clientsToProcess.size > 0) { + const clientIdsToProcess = Array.from(clientsToProcess.keys()); + const siteCounts = await trx + .select({ + clientId: clientSitesAssociationsCache.clientId, + siteCount: count(clientSitesAssociationsCache.siteId) + }) + .from(clientSitesAssociationsCache) + .where( + inArray( + clientSitesAssociationsCache.clientId, + clientIdsToProcess + ) + ) + .groupBy(clientSitesAssociationsCache.clientId); + + for (const row of siteCounts) { + clientSiteCounts[row.clientId] = Number(row.siteCount); + } + } + for (const client of clientsToProcess.values()) { // UPDATE THE NEWT if (!client.subnet || !client.pubKey) { @@ -582,7 +627,14 @@ async function handleMessagesForSiteClients( } if (isAdd) { - // TODO: if we are in jit mode here should we really be sending this? + if (clientSiteCounts[client.clientId] > 250) { + // skip adding the peer if we have more than 250 sites because we are in jit mode anyway + logger.info( + `rebuildClientAssociations: Client ${client.clientId} has ${clientSiteCounts[client.clientId]} sites so skipping adding peer to newt and olm because it is likely in jit mode` + ); + continue; + } + await initPeerAddHandshake( // this will kick off the add peer process for the client client.clientId, @@ -600,9 +652,24 @@ async function handleMessagesForSiteClients( exitNodeJobs.push(updateClientSiteDestinations(client, trx)); } - await Promise.all(exitNodeJobs); - await Promise.all(newtJobs); // do the servers first to make sure they are ready? - await Promise.all(olmJobs); + Promise.all(exitNodeJobs).catch((error) => { + logger.error( + `rebuildClientAssociations: Error updating client site destinations for site ${site.siteId}:`, + error + ); + }); + Promise.all(newtJobs).catch((error) => { + logger.error( + `rebuildClientAssociations: Error updating Newt peers for site ${site.siteId}:`, + error + ); + }); + Promise.all(olmJobs).catch((error) => { + logger.error( + `rebuildClientAssociations: Error updating Olm peers for site ${site.siteId}:`, + error + ); + }); } interface PeerDestination { @@ -885,6 +952,17 @@ async function handleSubnetProxyTargetUpdates( export async function rebuildClientAssociationsFromClient( client: Client, trx: Transaction | typeof db = db +): Promise { + return await lockManager.withLock( + `rebuild-client-associations:client:${client.clientId}`, + () => rebuildClientAssociationsFromClientImpl(client, trx), + REBUILD_ASSOCIATIONS_LOCK_TTL_MS + ); +} + +async function rebuildClientAssociationsFromClientImpl( + client: Client, + trx: Transaction | typeof db = db ): Promise { let newSiteResourceIds: number[] = []; @@ -1157,6 +1235,12 @@ async function handleMessagesForClientSites( const olmJobs: Promise[] = []; const exitNodeJobs: Promise[] = []; + const totalSitesOnClient = await trx + .select({ count: count(clientSitesAssociationsCache.siteId) }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)) + .then((rows) => Number(rows[0].count)); + for (const siteData of sitesData) { const site = siteData.sites; const exitNode = siteData.exitNodes; @@ -1217,7 +1301,14 @@ async function handleMessagesForClientSites( continue; } - // TODO: if we are in jit mode here should we really be sending this? + if (totalSitesOnClient > 250) { + // skip adding the site if we have more than 250 because we are in jit mode anyway + logger.info( + `rebuildClientAssociations: Client ${client.clientId} has ${totalSitesOnClient} sites so skipping adding peer to newt and olm because it is likely in jit mode` + ); + continue; + } + await initPeerAddHandshake( // this will kick off the add peer process for the client client.clientId, @@ -1245,9 +1336,24 @@ async function handleMessagesForClientSites( ); } - await Promise.all(exitNodeJobs); - await Promise.all(newtJobs); - await Promise.all(olmJobs); + Promise.all(exitNodeJobs).catch((error) => { + logger.error( + `rebuildClientAssociations: Error updating client site destinations for client ${client.clientId}:`, + error + ); + }); + Promise.all(newtJobs).catch((error) => { + logger.error( + `rebuildClientAssociations: Error updating Newt peers for client ${client.clientId}:`, + error + ); + }); + Promise.all(olmJobs).catch((error) => { + logger.error( + `rebuildClientAssociations: Error updating Olm peers for client ${client.clientId}:`, + error + ); + }); } async function handleMessagesForClientResources( @@ -1528,3 +1634,195 @@ async function handleMessagesForClientResources( await Promise.all([...proxyJobs, ...olmJobs]); } + +export type ClientAssociationsCacheVerification = { + clientId: number; + consistent: boolean; + // What permissions say the cache should contain + expectedSiteResourceIds: number[]; + expectedSiteIds: number[]; + // What the cache currently contains + actualSiteResourceIds: number[]; + actualSiteIds: number[]; + // Diff + missingSiteResourceIds: number[]; // present in expected, missing from cache + extraSiteResourceIds: number[]; // present in cache, not in expected + missingSiteIds: number[]; + extraSiteIds: number[]; +}; + +// verifyClientAssociationsCache walks the same permission-derivation logic as +// rebuildClientAssociationsFromClient but does NOT modify the database. It +// returns the expected vs actual cache contents and a boolean indicating +// whether the cache is in sync with what permissions imply. +export async function verifyClientAssociationsCache( + client: Client, + trx: Transaction | typeof db = db +): Promise { + let newSiteResourceIds: number[] = []; + + // 1. Direct client associations + const directSiteResources = await trx + .select({ siteResourceId: clientSiteResources.siteResourceId }) + .from(clientSiteResources) + .innerJoin( + siteResources, + eq(siteResources.siteResourceId, clientSiteResources.siteResourceId) + ) + .where( + and( + eq(clientSiteResources.clientId, client.clientId), + eq(siteResources.orgId, client.orgId) + ) + ); + + newSiteResourceIds.push( + ...directSiteResources.map((r) => r.siteResourceId) + ); + + // 2. User-based and role-based access (if client has a userId) + if (client.userId) { + const userSiteResourceIds = await trx + .select({ siteResourceId: userSiteResources.siteResourceId }) + .from(userSiteResources) + .innerJoin( + siteResources, + eq( + siteResources.siteResourceId, + userSiteResources.siteResourceId + ) + ) + .where( + and( + eq(userSiteResources.userId, client.userId), + eq(siteResources.orgId, client.orgId) + ) + ); + + newSiteResourceIds.push( + ...userSiteResourceIds.map((r) => r.siteResourceId) + ); + + const roleIds = await trx + .select({ roleId: userOrgRoles.roleId }) + .from(userOrgRoles) + .where( + and( + eq(userOrgRoles.userId, client.userId), + eq(userOrgRoles.orgId, client.orgId) + ) + ) + .then((rows) => rows.map((row) => row.roleId)); + + if (roleIds.length > 0) { + const roleSiteResourceIds = await trx + .select({ siteResourceId: roleSiteResources.siteResourceId }) + .from(roleSiteResources) + .innerJoin( + siteResources, + eq( + siteResources.siteResourceId, + roleSiteResources.siteResourceId + ) + ) + .where( + and( + inArray(roleSiteResources.roleId, roleIds), + eq(siteResources.orgId, client.orgId) + ) + ); + + newSiteResourceIds.push( + ...roleSiteResourceIds.map((r) => r.siteResourceId) + ); + } + } + + newSiteResourceIds = Array.from(new Set(newSiteResourceIds)); + + const newSiteResources = + newSiteResourceIds.length > 0 + ? await trx + .select() + .from(siteResources) + .where( + inArray(siteResources.siteResourceId, newSiteResourceIds) + ) + : []; + + const networkIds = Array.from( + new Set( + newSiteResources + .map((sr) => sr.networkId) + .filter((id): id is number => id !== null) + ) + ); + const newSiteIds = + networkIds.length > 0 + ? await trx + .select({ siteId: siteNetworks.siteId }) + .from(siteNetworks) + .where(inArray(siteNetworks.networkId, networkIds)) + .then((rows) => + Array.from(new Set(rows.map((r) => r.siteId))) + ) + : []; + + // Read the existing cache state + const existingResourceAssociations = await trx + .select({ + siteResourceId: clientSiteResourcesAssociationsCache.siteResourceId + }) + .from(clientSiteResourcesAssociationsCache) + .where( + eq(clientSiteResourcesAssociationsCache.clientId, client.clientId) + ); + const existingSiteResourceIds = existingResourceAssociations.map( + (r) => r.siteResourceId + ); + + const existingSiteAssociations = await trx + .select({ siteId: clientSitesAssociationsCache.siteId }) + .from(clientSitesAssociationsCache) + .where(eq(clientSitesAssociationsCache.clientId, client.clientId)); + const existingSiteIds = existingSiteAssociations.map((s) => s.siteId); + + const expectedSiteResourceSet = new Set(newSiteResourceIds); + const actualSiteResourceSet = new Set(existingSiteResourceIds); + const expectedSiteSet = new Set(newSiteIds); + const actualSiteSet = new Set(existingSiteIds); + + const missingSiteResourceIds = newSiteResourceIds.filter( + (id) => !actualSiteResourceSet.has(id) + ); + const extraSiteResourceIds = existingSiteResourceIds.filter( + (id) => !expectedSiteResourceSet.has(id) + ); + const missingSiteIds = newSiteIds.filter((id) => !actualSiteSet.has(id)); + const extraSiteIds = existingSiteIds.filter( + (id) => !expectedSiteSet.has(id) + ); + + const consistent = + missingSiteResourceIds.length === 0 && + extraSiteResourceIds.length === 0 && + missingSiteIds.length === 0 && + extraSiteIds.length === 0; + + return { + clientId: client.clientId, + consistent, + expectedSiteResourceIds: Array.from(expectedSiteResourceSet).sort( + (a, b) => a - b + ), + expectedSiteIds: Array.from(expectedSiteSet).sort((a, b) => a - b), + actualSiteResourceIds: Array.from(actualSiteResourceSet).sort( + (a, b) => a - b + ), + actualSiteIds: Array.from(actualSiteSet).sort((a, b) => a - b), + missingSiteResourceIds: missingSiteResourceIds.sort((a, b) => a - b), + extraSiteResourceIds: extraSiteResourceIds.sort((a, b) => a - b), + missingSiteIds: missingSiteIds.sort((a, b) => a - b), + extraSiteIds: extraSiteIds.sort((a, b) => a - b) + }; +} diff --git a/server/private/routers/external.ts b/server/private/routers/external.ts index a2667daa1..a48cc3813 100644 --- a/server/private/routers/external.ts +++ b/server/private/routers/external.ts @@ -31,6 +31,7 @@ import * as siteProvisioning from "#private/routers/siteProvisioning"; import * as eventStreamingDestination from "#private/routers/eventStreamingDestination"; import * as alertRule from "#private/routers/alertRule"; import * as healthChecks from "#private/routers/healthChecks"; +import * as client from "@server/routers/client"; import { verifyOrgAccess, @@ -775,3 +776,9 @@ authenticated.get( verifyUserHasAction(ActionsEnum.getTarget), healthChecks.getHealthCheckStatusHistory ); + +authenticated.get( + "/client/:clientId/verify-associations-cache", + verifyClientAccess, + client.verifyClientAssociationsCache +); diff --git a/server/private/routers/loginPage/upsertLoginPageBranding.ts b/server/private/routers/loginPage/upsertLoginPageBranding.ts index 7e0da2c53..958cdeb43 100644 --- a/server/private/routers/loginPage/upsertLoginPageBranding.ts +++ b/server/private/routers/loginPage/upsertLoginPageBranding.ts @@ -26,7 +26,6 @@ import logger from "@server/logger"; import { fromError } from "zod-validation-error"; import { eq, InferInsertModel } from "drizzle-orm"; import { build } from "@server/build"; -import { validateLocalPath } from "@app/lib/validateLocalPath"; import config from "#private/lib/config"; const paramsSchema = z.strictObject({ @@ -34,79 +33,7 @@ const paramsSchema = z.strictObject({ }); const bodySchema = z.strictObject({ - logoUrl: z - .union([ - z.literal(""), - z - .string() - .superRefine(async (urlOrPath, ctx) => { - const parseResult = z.url().safeParse(urlOrPath); - if (!parseResult.success) { - if (build !== "enterprise") { - ctx.addIssue({ - code: "custom", - message: "Must be a valid URL" - }); - return; - } else { - try { - validateLocalPath(urlOrPath); - } catch (error) { - ctx.addIssue({ - code: "custom", - message: "Must be either a valid image URL or a valid pathname starting with `/` and not containing query parameters, `..` or `*`" - }); - } finally { - return; - } - } - } - - try { - const response = await fetch(urlOrPath, { - method: "HEAD" - }).catch(() => { - // If HEAD fails (CORS or method not allowed), try GET - return fetch(urlOrPath, { method: "GET" }); - }); - - if (response.status !== 200) { - ctx.addIssue({ - code: "custom", - message: `Failed to load image. Please check that the URL is accessible.` - }); - return; - } - - const contentType = - response.headers.get("content-type") ?? ""; - if (!contentType.startsWith("image/")) { - ctx.addIssue({ - code: "custom", - message: `URL does not point to an image. Please provide a URL to an image file (e.g., .png, .jpg, .svg).` - }); - return; - } - } catch (error) { - let errorMessage = - "Unable to verify image URL. Please check that the URL is accessible and points to an image file."; - - if (error instanceof TypeError && error.message.includes("fetch")) { - errorMessage = - "Network error: Unable to reach the URL. Please check your internet connection and verify the URL is correct."; - } else if (error instanceof Error) { - errorMessage = `Error verifying URL: ${error.message}`; - } - - ctx.addIssue({ - code: "custom", - message: errorMessage - }); - } - }) - ]) - .transform((val) => (val === "" ? null : val)) - .nullish(), + logoUrl: z.string().optional(), logoWidth: z.coerce.number().min(1), logoHeight: z.coerce.number().min(1), resourceTitle: z.string(), diff --git a/server/routers/client/index.ts b/server/routers/client/index.ts index e195d1c52..145cdd306 100644 --- a/server/routers/client/index.ts +++ b/server/routers/client/index.ts @@ -10,3 +10,4 @@ export * from "./listUserDevices"; export * from "./updateClient"; export * from "./getClient"; export * from "./createUserClient"; +export * from "./verifyClientAssociationsCache"; diff --git a/server/routers/client/verifyClientAssociationsCache.ts b/server/routers/client/verifyClientAssociationsCache.ts new file mode 100644 index 000000000..6b701ded3 --- /dev/null +++ b/server/routers/client/verifyClientAssociationsCache.ts @@ -0,0 +1,83 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { db } from "@server/db"; +import { clients } from "@server/db"; +import { eq } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import { OpenAPITags, registry } from "@server/openApi"; +import { verifyClientAssociationsCache as verifyClientAssociationsCacheLib } from "@server/lib/rebuildClientAssociations"; + +const paramsSchema = z.strictObject({ + clientId: z.string().transform(Number).pipe(z.int().positive()) +}); + +registry.registerPath({ + method: "get", + path: "/client/{clientId}/verify-associations-cache", + description: + "Read-only check of whether the client's site/site-resource association cache matches what the current permissions imply.", + tags: [OpenAPITags.Client], + request: { + params: paramsSchema + }, + responses: {} +}); + +export async function verifyClientAssociationsCache( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const { clientId } = parsedParams.data; + + const [client] = await db + .select() + .from(clients) + .where(eq(clients.clientId, clientId)) + .limit(1); + + if (!client) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Client with ID ${clientId} not found` + ) + ); + } + + const report = await verifyClientAssociationsCacheLib(client); + + return response(res, { + data: report, + success: true, + error: false, + message: report.consistent + ? "Client association cache is consistent" + : "Client association cache is INCONSISTENT", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to verify client association cache" + ) + ); + } +} diff --git a/src/app/[orgId]/settings/clients/user/[niceId]/general/page.tsx b/src/app/[orgId]/settings/clients/user/[niceId]/general/page.tsx index c08551865..af2d5fdad 100644 --- a/src/app/[orgId]/settings/clients/user/[niceId]/general/page.tsx +++ b/src/app/[orgId]/settings/clients/user/[niceId]/general/page.tsx @@ -153,6 +153,37 @@ export default function GeneralPage() { const [approvalId, setApprovalId] = useState(null); const [isRefreshing, setIsRefreshing] = useState(false); const [, startTransition] = useTransition(); + const [cacheCheck, setCacheCheck] = useState(null); + const [isCheckingCache, setIsCheckingCache] = useState(false); + + const handleVerifyCache = async () => { + if (!client.clientId) return; + setIsCheckingCache(true); + try { + const res = await api.get( + `/client/${client.clientId}/verify-associations-cache` + ); + setCacheCheck(res.data.data); + } catch (e) { + toast({ + variant: "destructive", + title: "Cache check failed", + description: formatAxiosError(e, "Failed to verify cache") + }); + } finally { + setIsCheckingCache(false); + } + }; const { env } = useEnvContext(); const showApprovalFeatures = @@ -844,6 +875,65 @@ export default function GeneralPage() { )} + + {/* Hidden cache verification — subtle button, dev/admin diagnostic */} +
+ + {cacheCheck && ( +
+ {cacheCheck.consistent ? ( + + + Cache is consistent + + ) : ( +
+
+ + Cache is INCONSISTENT +
+
+ Missing site resources: [ + {cacheCheck.missingSiteResourceIds.join( + ", " + )} + ] +
+
+ Extra site resources: [ + {cacheCheck.extraSiteResourceIds.join(", ")} + ] +
+
+ Missing sites: [ + {cacheCheck.missingSiteIds.join(", ")}] +
+
+ Extra sites: [ + {cacheCheck.extraSiteIds.join(", ")}] +
+
+ )} +
+ )} +
); } diff --git a/src/components/AuthPageBrandingForm.tsx b/src/components/AuthPageBrandingForm.tsx index ca49a50ae..b5ededd91 100644 --- a/src/components/AuthPageBrandingForm.tsx +++ b/src/components/AuthPageBrandingForm.tsx @@ -98,16 +98,6 @@ const AuthPageFormSchema = z.object({ let errorMessage = "Unable to verify image URL. Please check that the URL is accessible and points to an image file."; - if ( - error instanceof TypeError && - error.message.includes("fetch") - ) { - errorMessage = - "Network error: Unable to reach the URL. Please check your internet connection and verify the URL is correct."; - } else if (error instanceof Error) { - errorMessage = `Error verifying URL: ${error.message}`; - } - ctx.addIssue({ code: "custom", message: errorMessage