diff --git a/server/routers/siteResource/createSiteResource.ts b/server/routers/siteResource/createSiteResource.ts index b9494776e..273c7c022 100644 --- a/server/routers/siteResource/createSiteResource.ts +++ b/server/routers/siteResource/createSiteResource.ts @@ -8,6 +8,7 @@ import { SiteResource, siteResources, sites, + siteSiteResources, userSiteResources } from "@server/db"; import { getUniqueSiteResourceName } from "@server/db/names"; @@ -23,7 +24,7 @@ import response from "@server/lib/response"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; import HttpCode from "@server/types/HttpCode"; -import { and, eq } from "drizzle-orm"; +import { and, eq, inArray } from "drizzle-orm"; import { NextFunction, Request, Response } from "express"; import createHttpError from "http-errors"; import { z } from "zod"; @@ -37,7 +38,7 @@ const createSiteResourceSchema = z .strictObject({ name: z.string().min(1).max(255), mode: z.enum(["host", "cidr", "port"]), - siteId: z.int(), + siteIds: z.array(z.int()), // protocol: z.enum(["tcp", "udp"]).optional(), // proxyPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(), @@ -159,7 +160,7 @@ export async function createSiteResource( const { orgId } = parsedParams.data; const { name, - siteId, + siteIds, mode, // protocol, // proxyPort, @@ -178,14 +179,14 @@ export async function createSiteResource( } = parsedBody.data; // Verify the site exists and belongs to the org - const [site] = await db + const sitesToAssign = await db .select() .from(sites) - .where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) + .where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId))) .limit(1); - if (!site) { - return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); + if (sitesToAssign.length !== siteIds.length) { + return next(createHttpError(HttpCode.NOT_FOUND, "Some site not found")); } const [org] = await db @@ -289,7 +290,6 @@ export async function createSiteResource( await db.transaction(async (trx) => { // Create the site resource const insertValues: typeof siteResources.$inferInsert = { - siteId, niceId, orgId, name, @@ -317,6 +317,13 @@ export async function createSiteResource( //////////////////// update the associations //////////////////// + for (const siteId of siteIds) { + await trx.insert(siteSiteResources).values({ + siteId: siteId, + siteResourceId: siteResourceId + }); + } + const [adminRole] = await trx .select() .from(roles) @@ -359,17 +366,18 @@ export async function createSiteResource( ); } - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, site.siteId)) - .limit(1); + // Not sure what this is doing?? + // const [newt] = await trx + // .select() + // .from(newts) + // .where(eq(newts.siteId, site.siteId)) + // .limit(1); - if (!newt) { - return next( - createHttpError(HttpCode.NOT_FOUND, "Newt not found") - ); - } + // if (!newt) { + // return next( + // createHttpError(HttpCode.NOT_FOUND, "Newt not found") + // ); + // } await rebuildClientAssociationsFromSiteResource( newSiteResource, @@ -387,7 +395,7 @@ export async function createSiteResource( } logger.info( - `Created site resource ${newSiteResource.siteResourceId} for site ${siteId}` + `Created site resource ${newSiteResource.siteResourceId} for org ${orgId}` ); return response(res, { diff --git a/server/routers/siteResource/updateSiteResource.ts b/server/routers/siteResource/updateSiteResource.ts index 596ed9a3f..f22c5a047 100644 --- a/server/routers/siteResource/updateSiteResource.ts +++ b/server/routers/siteResource/updateSiteResource.ts @@ -9,6 +9,7 @@ import { roles, roleSiteResources, sites, + siteSiteResources, Transaction, userSiteResources } from "@server/db"; @@ -16,7 +17,7 @@ import { siteResources, SiteResource } from "@server/db"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; -import { eq, and, ne } from "drizzle-orm"; +import { eq, and, ne, inArray } from "drizzle-orm"; import { fromError } from "zod-validation-error"; import logger from "@server/logger"; import { OpenAPITags, registry } from "@server/openApi"; @@ -42,7 +43,7 @@ const updateSiteResourceParamsSchema = z.strictObject({ const updateSiteResourceSchema = z .strictObject({ name: z.string().min(1).max(255).optional(), - siteId: z.int(), + siteIds: z.array(z.int()), // niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(), // mode: z.enum(["host", "cidr", "port"]).optional(), mode: z.enum(["host", "cidr"]).optional(), @@ -166,7 +167,7 @@ export async function updateSiteResource( const { siteResourceId } = parsedParams.data; const { name, - siteId, // because it can change + siteIds, // because it can change mode, destination, alias, @@ -181,16 +182,6 @@ export async function updateSiteResource( authDaemonMode } = parsedBody.data; - const [site] = await db - .select() - .from(sites) - .where(eq(sites.siteId, siteId)) - .limit(1); - - if (!site) { - return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); - } - // Check if site resource exists const [existingSiteResource] = await db .select() @@ -230,6 +221,24 @@ export async function updateSiteResource( ); } + // Verify the site exists and belongs to the org + const sitesToAssign = await db + .select() + .from(sites) + .where( + and( + inArray(sites.siteId, siteIds), + eq(sites.orgId, existingSiteResource.orgId) + ) + ) + .limit(1); + + if (sitesToAssign.length !== siteIds.length) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Some site not found") + ); + } + // Only check if destination is an IP address const isIp = z .union([z.ipv4(), z.ipv6()]) @@ -247,25 +256,20 @@ export async function updateSiteResource( ); } - let existingSite = site; - let siteChanged = false; - if (existingSiteResource.siteId !== siteId) { - siteChanged = true; - // get the existing site - [existingSite] = await db - .select() - .from(sites) - .where(eq(sites.siteId, existingSiteResource.siteId)) - .limit(1); + let sitesChanged = false; + const existingSiteIds = await db + .select() + .from(siteSiteResources) + .where(eq(siteSiteResources.siteResourceId, siteResourceId)); - if (!existingSite) { - return next( - createHttpError( - HttpCode.NOT_FOUND, - "Existing site not found" - ) - ); - } + const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId)); + const newSiteIdSet = new Set(siteIds); + + if ( + existingSiteIdSet.size !== newSiteIdSet.size || + ![...existingSiteIdSet].every((id) => newSiteIdSet.has(id)) + ) { + sitesChanged = true; } // make sure the alias is unique within the org if provided @@ -295,7 +299,7 @@ export async function updateSiteResource( let updatedSiteResource: SiteResource | undefined; await db.transaction(async (trx) => { // if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place - if (siteChanged) { + if (sitesChanged) { // delete the existing site resource await trx .delete(siteResources) @@ -321,7 +325,8 @@ export async function updateSiteResource( const sshPamSet = isLicensedSshPam && - (authDaemonPort !== undefined || authDaemonMode !== undefined) + (authDaemonPort !== undefined || + authDaemonMode !== undefined) ? { ...(authDaemonPort !== undefined && { authDaemonPort @@ -335,7 +340,6 @@ export async function updateSiteResource( .update(siteResources) .set({ name: name, - siteId: siteId, mode: mode, destination: destination, enabled: enabled, @@ -423,7 +427,8 @@ export async function updateSiteResource( // Update the site resource const sshPamSet = isLicensedSshPam && - (authDaemonPort !== undefined || authDaemonMode !== undefined) + (authDaemonPort !== undefined || + authDaemonMode !== undefined) ? { ...(authDaemonPort !== undefined && { authDaemonPort @@ -437,7 +442,6 @@ export async function updateSiteResource( .update(siteResources) .set({ name: name, - siteId: siteId, mode: mode, destination: destination, enabled: enabled, @@ -454,6 +458,20 @@ export async function updateSiteResource( //////////////////// update the associations //////////////////// + // delete the site - site resources associations + await trx + .delete(siteSiteResources) + .where( + eq(siteSiteResources.siteResourceId, siteResourceId) + ); + + for (const siteId of siteIds) { + await trx.insert(siteSiteResources).values({ + siteId: siteId, + siteResourceId: siteResourceId + }); + } + await trx .delete(clientSiteResources) .where( @@ -524,13 +542,16 @@ export async function updateSiteResource( } logger.info( - `Updated site resource ${siteResourceId} for site ${siteId}` + `Updated site resource ${siteResourceId}` ); await handleMessagingForUpdatedSiteResource( existingSiteResource, updatedSiteResource, - { siteId: site.siteId, orgId: site.orgId }, + siteIds.map((siteId) => ({ + siteId, + orgId: existingSiteResource.orgId + })), trx ); } @@ -557,7 +578,7 @@ export async function updateSiteResource( export async function handleMessagingForUpdatedSiteResource( existingSiteResource: SiteResource | undefined, updatedSiteResource: SiteResource, - site: { siteId: number; orgId: string }, + sites: { siteId: number; orgId: string }[], trx: Transaction ) { logger.debug( @@ -594,101 +615,116 @@ export async function handleMessagingForUpdatedSiteResource( // if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all if (destinationChanged || aliasChanged || portRangesChanged) { - const [newt] = await trx - .select() - .from(newts) - .where(eq(newts.siteId, site.siteId)) - .limit(1); - - if (!newt) { - throw new Error( - "Newt not found for site during site resource update" - ); - } - - // Only update targets on newt if destination changed - if (destinationChanged || portRangesChanged) { - const oldTargets = generateSubnetProxyTargets( - existingSiteResource, - mergedAllClients - ); - const newTargets = generateSubnetProxyTargets( - updatedSiteResource, - mergedAllClients - ); - - await updateTargets(newt.newtId, { - oldTargets: oldTargets, - newTargets: newTargets - }, newt.version); - } - - const olmJobs: Promise[] = []; - for (const client of mergedAllClients) { - // does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet - // todo: optimize this query if needed - const oldDestinationStillInUseSites = await trx + for (const site of sites) { + const [newt] = await trx .select() - .from(siteResources) - .innerJoin( - clientSiteResourcesAssociationsCache, - eq( - clientSiteResourcesAssociationsCache.siteResourceId, - siteResources.siteResourceId - ) - ) - .where( - and( - eq( - clientSiteResourcesAssociationsCache.clientId, - client.clientId - ), - eq(siteResources.siteId, site.siteId), - eq( - siteResources.destination, - existingSiteResource.destination - ), - ne( - siteResources.siteResourceId, - existingSiteResource.siteResourceId - ) - ) + .from(newts) + .where(eq(newts.siteId, site.siteId)) + .limit(1); + + if (!newt) { + throw new Error( + "Newt not found for site during site resource update" + ); + } + + // Only update targets on newt if destination changed + if (destinationChanged || portRangesChanged) { + const oldTargets = generateSubnetProxyTargets( + existingSiteResource, + mergedAllClients + ); + const newTargets = generateSubnetProxyTargets( + updatedSiteResource, + mergedAllClients ); - const oldDestinationStillInUseByASite = - oldDestinationStillInUseSites.length > 0; + await updateTargets( + newt.newtId, + { + oldTargets: oldTargets, + newTargets: newTargets + }, + newt.version + ); + } - // we also need to update the remote subnets on the olms for each client that has access to this site - olmJobs.push( - updatePeerData( - client.clientId, - updatedSiteResource.siteId, - destinationChanged - ? { - oldRemoteSubnets: !oldDestinationStillInUseByASite - ? generateRemoteSubnets([ - existingSiteResource - ]) - : [], - newRemoteSubnets: generateRemoteSubnets([ - updatedSiteResource - ]) - } - : undefined, - aliasChanged - ? { - oldAliases: generateAliasConfig([ - existingSiteResource - ]), - newAliases: generateAliasConfig([ - updatedSiteResource - ]) - } - : undefined - ) - ); + const olmJobs: Promise[] = []; + for (const client of mergedAllClients) { + // does this client have access to another resource on this site that has the same destination still? if so we dont want to remove it from their olm yet + // todo: optimize this query if needed + const oldDestinationStillInUseSites = await trx + .select() + .from(siteResources) + .innerJoin( + clientSiteResourcesAssociationsCache, + eq( + clientSiteResourcesAssociationsCache.siteResourceId, + siteResources.siteResourceId + ) + ) + .innerJoin( + siteSiteResources, + eq( + siteSiteResources.siteResourceId, + siteResources.siteResourceId + ) + ) + .where( + and( + eq( + clientSiteResourcesAssociationsCache.clientId, + client.clientId + ), + eq(siteSiteResources.siteId, site.siteId), + eq( + siteResources.destination, + existingSiteResource.destination + ), + ne( + siteResources.siteResourceId, + existingSiteResource.siteResourceId + ) + ) + ); + + + const oldDestinationStillInUseByASite = + oldDestinationStillInUseSites.length > 0; + + // we also need to update the remote subnets on the olms for each client that has access to this site + olmJobs.push( + updatePeerData( + client.clientId, + site.siteId, + destinationChanged + ? { + oldRemoteSubnets: + !oldDestinationStillInUseByASite + ? generateRemoteSubnets([ + existingSiteResource + ]) + : [], + newRemoteSubnets: generateRemoteSubnets([ + updatedSiteResource + ]) + } + : undefined, + aliasChanged + ? { + oldAliases: generateAliasConfig([ + existingSiteResource + ]), + newAliases: generateAliasConfig([ + updatedSiteResource + ]) + } + : undefined + ) + ); + } + + await Promise.all(olmJobs); } - - await Promise.all(olmJobs); } }