Adjust create and update to be many to one

This commit is contained in:
Owen
2026-03-18 20:54:49 -07:00
parent 102a235407
commit d8b511b198
2 changed files with 193 additions and 149 deletions

View File

@@ -8,6 +8,7 @@ import {
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
siteSiteResources,
userSiteResources userSiteResources
} from "@server/db"; } from "@server/db";
import { getUniqueSiteResourceName } from "@server/db/names"; import { getUniqueSiteResourceName } from "@server/db/names";
@@ -23,7 +24,7 @@ import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -37,7 +38,7 @@ const createSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]), mode: z.enum(["host", "cidr", "port"]),
siteId: z.int(), siteIds: z.array(z.int()),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),
@@ -159,7 +160,7 @@ export async function createSiteResource(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { const {
name, name,
siteId, siteIds,
mode, mode,
// protocol, // protocol,
// proxyPort, // proxyPort,
@@ -178,14 +179,14 @@ export async function createSiteResource(
} = parsedBody.data; } = parsedBody.data;
// Verify the site exists and belongs to the org // Verify the site exists and belongs to the org
const [site] = await db const sitesToAssign = await db
.select() .select()
.from(sites) .from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) .where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId)))
.limit(1); .limit(1);
if (!site) { if (sitesToAssign.length !== siteIds.length) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); return next(createHttpError(HttpCode.NOT_FOUND, "Some site not found"));
} }
const [org] = await db const [org] = await db
@@ -289,7 +290,6 @@ export async function createSiteResource(
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// Create the site resource // Create the site resource
const insertValues: typeof siteResources.$inferInsert = { const insertValues: typeof siteResources.$inferInsert = {
siteId,
niceId, niceId,
orgId, orgId,
name, name,
@@ -317,6 +317,13 @@ export async function createSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
for (const siteId of siteIds) {
await trx.insert(siteSiteResources).values({
siteId: siteId,
siteResourceId: siteResourceId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -359,17 +366,18 @@ export async function createSiteResource(
); );
} }
const [newt] = await trx // Not sure what this is doing??
.select() // const [newt] = await trx
.from(newts) // .select()
.where(eq(newts.siteId, site.siteId)) // .from(newts)
.limit(1); // .where(eq(newts.siteId, site.siteId))
// .limit(1);
if (!newt) { // if (!newt) {
return next( // return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") // createHttpError(HttpCode.NOT_FOUND, "Newt not found")
); // );
} // }
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
newSiteResource, newSiteResource,
@@ -387,7 +395,7 @@ export async function createSiteResource(
} }
logger.info( logger.info(
`Created site resource ${newSiteResource.siteResourceId} for site ${siteId}` `Created site resource ${newSiteResource.siteResourceId} for org ${orgId}`
); );
return response(res, { return response(res, {

View File

@@ -9,6 +9,7 @@ import {
roles, roles,
roleSiteResources, roleSiteResources,
sites, sites,
siteSiteResources,
Transaction, Transaction,
userSiteResources userSiteResources
} from "@server/db"; } from "@server/db";
@@ -16,7 +17,7 @@ import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { eq, and, ne } from "drizzle-orm"; import { eq, and, ne, inArray } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -42,7 +43,7 @@ const updateSiteResourceParamsSchema = z.strictObject({
const updateSiteResourceSchema = z const updateSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
siteId: z.int(), siteIds: z.array(z.int()),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(), // niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(), // mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(), mode: z.enum(["host", "cidr"]).optional(),
@@ -166,7 +167,7 @@ export async function updateSiteResource(
const { siteResourceId } = parsedParams.data; const { siteResourceId } = parsedParams.data;
const { const {
name, name,
siteId, // because it can change siteIds, // because it can change
mode, mode,
destination, destination,
alias, alias,
@@ -181,16 +182,6 @@ export async function updateSiteResource(
authDaemonMode authDaemonMode
} = parsedBody.data; } = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists // Check if site resource exists
const [existingSiteResource] = await db const [existingSiteResource] = await db
.select() .select()
@@ -230,6 +221,24 @@ export async function updateSiteResource(
); );
} }
// Verify the site exists and belongs to the org
const sitesToAssign = await db
.select()
.from(sites)
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.orgId, existingSiteResource.orgId)
)
)
.limit(1);
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
// Only check if destination is an IP address // Only check if destination is an IP address
const isIp = z const isIp = z
.union([z.ipv4(), z.ipv6()]) .union([z.ipv4(), z.ipv6()])
@@ -247,25 +256,20 @@ export async function updateSiteResource(
); );
} }
let existingSite = site; let sitesChanged = false;
let siteChanged = false; const existingSiteIds = await db
if (existingSiteResource.siteId !== siteId) {
siteChanged = true;
// get the existing site
[existingSite] = await db
.select() .select()
.from(sites) .from(siteSiteResources)
.where(eq(sites.siteId, existingSiteResource.siteId)) .where(eq(siteSiteResources.siteResourceId, siteResourceId));
.limit(1);
if (!existingSite) { const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId));
return next( const newSiteIdSet = new Set(siteIds);
createHttpError(
HttpCode.NOT_FOUND, if (
"Existing site not found" existingSiteIdSet.size !== newSiteIdSet.size ||
) ![...existingSiteIdSet].every((id) => newSiteIdSet.has(id))
); ) {
} sitesChanged = true;
} }
// make sure the alias is unique within the org if provided // make sure the alias is unique within the org if provided
@@ -295,7 +299,7 @@ export async function updateSiteResource(
let updatedSiteResource: SiteResource | undefined; let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place // if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place
if (siteChanged) { if (sitesChanged) {
// delete the existing site resource // delete the existing site resource
await trx await trx
.delete(siteResources) .delete(siteResources)
@@ -321,7 +325,8 @@ export async function updateSiteResource(
const sshPamSet = const sshPamSet =
isLicensedSshPam && isLicensedSshPam &&
(authDaemonPort !== undefined || authDaemonMode !== undefined) (authDaemonPort !== undefined ||
authDaemonMode !== undefined)
? { ? {
...(authDaemonPort !== undefined && { ...(authDaemonPort !== undefined && {
authDaemonPort authDaemonPort
@@ -335,7 +340,6 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -423,7 +427,8 @@ export async function updateSiteResource(
// Update the site resource // Update the site resource
const sshPamSet = const sshPamSet =
isLicensedSshPam && isLicensedSshPam &&
(authDaemonPort !== undefined || authDaemonMode !== undefined) (authDaemonPort !== undefined ||
authDaemonMode !== undefined)
? { ? {
...(authDaemonPort !== undefined && { ...(authDaemonPort !== undefined && {
authDaemonPort authDaemonPort
@@ -437,7 +442,6 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -454,6 +458,20 @@ export async function updateSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteSiteResources)
.where(
eq(siteSiteResources.siteResourceId, siteResourceId)
);
for (const siteId of siteIds) {
await trx.insert(siteSiteResources).values({
siteId: siteId,
siteResourceId: siteResourceId
});
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where( .where(
@@ -524,13 +542,16 @@ export async function updateSiteResource(
} }
logger.info( logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}` `Updated site resource ${siteResourceId}`
); );
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
existingSiteResource, existingSiteResource,
updatedSiteResource, updatedSiteResource,
{ siteId: site.siteId, orgId: site.orgId }, siteIds.map((siteId) => ({
siteId,
orgId: existingSiteResource.orgId
})),
trx trx
); );
} }
@@ -557,7 +578,7 @@ export async function updateSiteResource(
export async function handleMessagingForUpdatedSiteResource( export async function handleMessagingForUpdatedSiteResource(
existingSiteResource: SiteResource | undefined, existingSiteResource: SiteResource | undefined,
updatedSiteResource: SiteResource, updatedSiteResource: SiteResource,
site: { siteId: number; orgId: string }, sites: { siteId: number; orgId: string }[],
trx: Transaction trx: Transaction
) { ) {
logger.debug( logger.debug(
@@ -594,6 +615,7 @@ export async function handleMessagingForUpdatedSiteResource(
// if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all // if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all
if (destinationChanged || aliasChanged || portRangesChanged) { if (destinationChanged || aliasChanged || portRangesChanged) {
for (const site of sites) {
const [newt] = await trx const [newt] = await trx
.select() .select()
.from(newts) .from(newts)
@@ -617,10 +639,14 @@ export async function handleMessagingForUpdatedSiteResource(
mergedAllClients mergedAllClients
); );
await updateTargets(newt.newtId, { await updateTargets(
newt.newtId,
{
oldTargets: oldTargets, oldTargets: oldTargets,
newTargets: newTargets newTargets: newTargets
}, newt.version); },
newt.version
);
} }
const olmJobs: Promise<void>[] = []; const olmJobs: Promise<void>[] = [];
@@ -637,13 +663,20 @@ export async function handleMessagingForUpdatedSiteResource(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteSiteResources,
eq(
siteSiteResources.siteResourceId,
siteResources.siteResourceId
)
)
.where( .where(
and( and(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
), ),
eq(siteResources.siteId, site.siteId), eq(siteSiteResources.siteId, site.siteId),
eq( eq(
siteResources.destination, siteResources.destination,
existingSiteResource.destination existingSiteResource.destination
@@ -655,6 +688,7 @@ export async function handleMessagingForUpdatedSiteResource(
) )
); );
const oldDestinationStillInUseByASite = const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0; oldDestinationStillInUseSites.length > 0;
@@ -662,10 +696,11 @@ export async function handleMessagingForUpdatedSiteResource(
olmJobs.push( olmJobs.push(
updatePeerData( updatePeerData(
client.clientId, client.clientId,
updatedSiteResource.siteId, site.siteId,
destinationChanged destinationChanged
? { ? {
oldRemoteSubnets: !oldDestinationStillInUseByASite oldRemoteSubnets:
!oldDestinationStillInUseByASite
? generateRemoteSubnets([ ? generateRemoteSubnets([
existingSiteResource existingSiteResource
]) ])
@@ -691,4 +726,5 @@ export async function handleMessagingForUpdatedSiteResource(
await Promise.all(olmJobs); await Promise.all(olmJobs);
} }
}
} }