Build client site resource associations and send messages

This commit is contained in:
Owen
2025-11-19 18:05:42 -05:00
parent 806949879a
commit 937b36e756
36 changed files with 904 additions and 583 deletions

View File

@@ -69,9 +69,9 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
)
);
logger.debug(
`Cleaned up request audit logs older than ${retentionDays} days`
);
// logger.debug(
// `Cleaned up request audit logs older than ${retentionDays} days`
// );
} catch (error) {
logger.error("Error cleaning up old request audit logs:", error);
}

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -70,8 +70,8 @@ export async function deleteClient(
await db.transaction(async (trx) => {
// Delete the client-site associations first
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, clientId));
.delete(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.clientId, clientId));
// Then delete the client itself
await trx.delete(clients).where(eq(clients.clientId, clientId));

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { clients, clientSites } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db";
import { eq, and } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -29,9 +29,9 @@ async function query(clientId: number) {
// Get the siteIds associated with this client
const sites = await db
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
.select({ siteId: clientSitesAssociationsCache.siteId })
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.clientId, clientId));
// Add the siteIds to the client object
return {

View File

@@ -5,7 +5,7 @@ import {
roleClients,
sites,
userClients,
clientSites
clientSitesAssociationsCache
} from "@server/db";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
@@ -142,14 +142,14 @@ async function getSiteAssociations(clientIds: number[]) {
return db
.select({
clientId: clientSites.clientId,
siteId: clientSites.siteId,
clientId: clientSitesAssociationsCache.clientId,
siteId: clientSitesAssociationsCache.siteId,
siteName: sites.name,
siteNiceId: sites.niceId
})
.from(clientSites)
.leftJoin(sites, eq(clientSites.siteId, sites.siteId))
.where(inArray(clientSites.clientId, clientIds));
.from(clientSitesAssociationsCache)
.leftJoin(sites, eq(clientSitesAssociationsCache.siteId, sites.siteId))
.where(inArray(clientSitesAssociationsCache.clientId, clientIds));
}
type OlmWithUpdateAvailable = Awaited<ReturnType<typeof queryClients>>[0] & {

View File

@@ -3,7 +3,7 @@ import { SubnetProxyTarget } from "@server/lib/ip";
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
await sendToClient(newtId, {
type: `newt/wg/target/add`,
type: `newt/wg/targets/add`,
data: targets
});
}
@@ -13,7 +13,7 @@ export async function removeTargets(
targets: SubnetProxyTarget[]
) {
await sendToClient(newtId, {
type: `newt/wg/target/remove`,
type: `newt/wg/targets/remove`,
data: targets
});
}
@@ -26,7 +26,7 @@ export async function updateTargets(
}
) {
await sendToClient(newtId, {
type: `newt/wg/target/update`,
type: `newt/wg/targets/update`,
data: targets
});
}

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { Client, db, exitNodes, olms, sites } from "@server/db";
import { clients, clientSites } from "@server/db";
import { clients, clientSitesAssociationsCache } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";

View File

@@ -7,7 +7,7 @@ import {
olms,
Site,
sites,
clientSites,
clientSitesAssociationsCache,
ExitNode
} from "@server/db";
import { db } from "@server/db";
@@ -109,8 +109,8 @@ export async function generateRelayMappings(exitNode: ExitNode) {
// Find all clients associated with this site through clientSites
const clientSitesRes = await db
.select()
.from(clientSites)
.where(eq(clientSites.siteId, site.siteId));
.from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.siteId, site.siteId));
for (const clientSite of clientSitesRes) {
if (!clientSite.endpoint) {

View File

@@ -6,7 +6,7 @@ import {
olms,
Site,
sites,
clientSites,
clientSitesAssociationsCache,
exitNodes,
ExitNode
} from "@server/db";
@@ -174,11 +174,11 @@ export async function updateAndGenerateEndpointDestinations(
listenPort: sites.listenPort
})
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.innerJoin(clientSitesAssociationsCache, eq(sites.siteId, clientSitesAssociationsCache.siteId))
.where(
and(
eq(sites.exitNodeId, exitNode.exitNodeId),
eq(clientSites.clientId, olm.clientId)
eq(clientSitesAssociationsCache.clientId, olm.clientId)
)
);
@@ -189,14 +189,14 @@ export async function updateAndGenerateEndpointDestinations(
);
await db
.update(clientSites)
.update(clientSitesAssociationsCache)
.set({
endpoint: `${ip}:${port}`
})
.where(
and(
eq(clientSites.clientId, olm.clientId),
eq(clientSites.siteId, site.siteId)
eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
}

View File

@@ -11,7 +11,7 @@ import {
Target,
targets
} from "@server/db";
import { clients, clientSites, Newt, sites } from "@server/db";
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
import { eq, and, inArray } from "drizzle-orm";
import { updatePeer } from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
@@ -138,8 +138,8 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
const clientsRes = await db
.select()
.from(clients)
.innerJoin(clientSites, eq(clients.clientId, clientSites.clientId))
.where(eq(clientSites.siteId, siteId));
.innerJoin(clientSitesAssociationsCache, eq(clients.clientId, clientSitesAssociationsCache.clientId))
.where(eq(clientSitesAssociationsCache.siteId, siteId));
// Prepare peers data for the response
const peers = await Promise.all(

View File

@@ -1,6 +1,6 @@
import { NextFunction, Request, Response } from "express";
import { db } from "@server/db";
import { olms, clients, clientSites } from "@server/db";
import { olms, clients, clientSitesAssociationsCache } from "@server/db";
import { eq } from "drizzle-orm";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -57,8 +57,8 @@ export async function deleteUserOlm(
// Delete client-site associations for each associated client
for (const client of associatedClients) {
await trx
.delete(clientSites)
.where(eq(clientSites.clientId, client.clientId));
.delete(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
}
// Delete all associated clients

View File

@@ -12,7 +12,7 @@ import {
users
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, exitNodes, Olm, olms, sites } from "@server/db";
import { clients, clientSitesAssociationsCache, exitNodes, Olm, olms, sites } from "@server/db";
import { and, eq, inArray, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger";
@@ -159,19 +159,19 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// set isRelay to false for all of the client's sites to reset the connection metadata
await db
.update(clientSites)
.update(clientSitesAssociationsCache)
.set({
isRelayed: relay == true
})
.where(eq(clientSites.clientId, client.clientId));
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
}
// Get all sites data
const sitesData = await db
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.where(eq(clientSites.clientId, client.clientId));
.innerJoin(clientSitesAssociationsCache, eq(sites.siteId, clientSitesAssociationsCache.siteId))
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Prepare an array to store site configurations
const siteConfigurations = [];
@@ -225,11 +225,11 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const [clientSite] = await db
.select()
.from(clientSites)
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSites.clientId, client.clientId),
eq(clientSites.siteId, site.siteId)
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
)
.limit(1);

View File

@@ -1,6 +1,6 @@
import { db, exitNodes, sites } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, clientSites, Olm } from "@server/db";
import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
import { and, eq } from "drizzle-orm";
import { updatePeer } from "../newt/peers";
import logger from "@server/logger";
@@ -67,14 +67,14 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
}
await db
.update(clientSites)
.update(clientSitesAssociationsCache)
.set({
isRelayed: true
})
.where(
and(
eq(clientSites.clientId, olm.clientId),
eq(clientSites.siteId, siteId)
eq(clientSitesAssociationsCache.clientId, olm.clientId),
eq(clientSitesAssociationsCache.siteId, siteId)
)
);

View File

@@ -8,7 +8,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const addClientToSiteResourceBodySchema = z
.object({
@@ -136,7 +136,7 @@ export async function addClientToSiteResource(
siteResourceId
});
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -9,7 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const addRoleToSiteResourceBodySchema = z
.object({
@@ -146,7 +146,7 @@ export async function addRoleToSiteResource(
siteResourceId
});
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -9,7 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const addUserToSiteResourceBodySchema = z
.object({
@@ -115,7 +115,7 @@ export async function addUserToSiteResource(
siteResourceId
});
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -1,7 +1,14 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts, roleResources, roles, roleSiteResources } from "@server/db";
import { siteResources, sites, orgs, SiteResource } from "@server/db";
import {
clientSiteResources,
db,
newts,
roles,
roleSiteResources,
userSiteResources
} from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -9,10 +16,8 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { addTargets } from "../client/targets";
import { getUniqueSiteResourceName } from "@server/db/names";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { generateSubnetProxyTargets } from "@server/lib/ip";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const createSiteResourceParamsSchema = z.strictObject({
siteId: z.string().transform(Number).pipe(z.int().positive()),
@@ -23,12 +28,15 @@ const createSiteResourceSchema = z
.strictObject({
name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]),
protocol: z.enum(["tcp", "udp"]).optional(),
// protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(),
destination: z.string().min(1),
enabled: z.boolean().default(true),
alias: z.string().optional()
alias: z.string().optional(),
userIds: z.array(z.string()),
roleIds: z.array(z.int()),
clientIds: z.array(z.int())
})
.strict()
// .refine(
@@ -138,12 +146,15 @@ export async function createSiteResource(
const {
name,
mode,
protocol,
// protocol,
// proxyPort,
// destinationPort,
destination,
enabled,
alias
alias,
userIds,
roleIds,
clientIds
} = parsedBody.data;
// Verify the site exists and belongs to the org
@@ -194,7 +205,7 @@ export async function createSiteResource(
orgId,
name,
mode,
protocol: mode === "port" ? protocol : null,
// protocol: mode === "port" ? protocol : null,
// proxyPort: mode === "port" ? proxyPort : null,
// destinationPort: mode === "port" ? destinationPort : null,
destination,
@@ -203,6 +214,10 @@ export async function createSiteResource(
})
.returning();
const siteResourceId = newSiteResource.siteResourceId;
//////////////////// update the associations ////////////////////
const [adminRole] = await trx
.select()
.from(roles)
@@ -217,9 +232,34 @@ export async function createSiteResource(
await trx.insert(roleSiteResources).values({
roleId: adminRole.roleId,
siteResourceId: newSiteResource.siteResourceId
siteResourceId: siteResourceId
});
if (roleIds.length > 0) {
await trx
.insert(roleSiteResources)
.values(
roleIds.map((roleId) => ({ roleId, siteResourceId }))
);
}
if (userIds.length > 0) {
await trx
.insert(userSiteResources)
.values(
userIds.map((userId) => ({ userId, siteResourceId }))
);
}
if (clientIds.length > 0) {
await trx.insert(clientSiteResources).values(
clientIds.map((clientId) => ({
clientId,
siteResourceId
}))
);
}
const [newt] = await trx
.select()
.from(newts)
@@ -232,10 +272,10 @@ export async function createSiteResource(
);
}
const targets = await generateSubnetProxyTargets([newSiteResource], trx);
await addTargets(newt.newtId, targets);
// const targets = await generateSubnetProxyTargets([newSiteResource], trx);
// await addTargets(newt.newtId, targets);
await rebuildSiteClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role
await rebuildClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role
});
if (!newSiteResource) {

View File

@@ -9,9 +9,7 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { removeTargets } from "../client/targets";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { generateSubnetProxyTargets } from "@server/lib/ip";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const deleteSiteResourceParamsSchema = z.strictObject({
siteResourceId: z.string().transform(Number).pipe(z.int().positive()),
@@ -108,10 +106,10 @@ export async function deleteSiteResource(
);
}
const targets = await generateSubnetProxyTargets([removedSiteResource], trx);
await removeTargets(newt.newtId, targets);
// const targets = await generateSubnetProxyTargets([removedSiteResource], trx);
// await removeTargets(newt.newtId, targets);
await rebuildSiteClientAssociations(existingSiteResource, trx);
await rebuildClientAssociations(existingSiteResource, trx);
});
logger.info(

View File

@@ -8,7 +8,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const removeClientFromSiteResourceBodySchema = z
.object({
@@ -142,7 +142,7 @@ export async function removeClientFromSiteResource(
)
);
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -9,7 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const removeRoleFromSiteResourceBodySchema = z
.object({
@@ -151,7 +151,7 @@ export async function removeRoleFromSiteResource(
)
);
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -9,7 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const removeUserFromSiteResourceBodySchema = z
.object({
@@ -121,7 +121,7 @@ export async function removeUserFromSiteResource(
)
);
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -8,7 +8,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, inArray } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const setSiteResourceClientsBodySchema = z
.object({
@@ -119,17 +119,12 @@ export async function setSiteResourceClients(
.where(eq(clientSiteResources.siteResourceId, siteResourceId));
if (clientIds.length > 0) {
await Promise.all(
clientIds.map((clientId) =>
trx
.insert(clientSiteResources)
.values({ clientId, siteResourceId })
.returning()
)
);
await trx
.insert(clientSiteResources)
.values(clientIds.map((clientId) => ({ clientId, siteResourceId })));
}
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -9,7 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and, ne, inArray } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const setSiteResourceRolesBodySchema = z
.object({
@@ -141,16 +141,13 @@ export async function setSiteResourceRoles(
);
}
await Promise.all(
roleIds.map((roleId) =>
trx
.insert(roleSiteResources)
.values({ roleId, siteResourceId })
.returning()
)
);
if (roleIds.length > 0) {
await trx
.insert(roleSiteResources)
.values(roleIds.map((roleId) => ({ roleId, siteResourceId })));
}
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -9,7 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
import { rebuildClientAssociations } from "@server/lib/rebuildClientAssociations";
const setSiteResourceUsersBodySchema = z
.object({
@@ -96,16 +96,13 @@ export async function setSiteResourceUsers(
.delete(userSiteResources)
.where(eq(userSiteResources.siteResourceId, siteResourceId));
await Promise.all(
userIds.map((userId) =>
trx
.insert(userSiteResources)
.values({ userId, siteResourceId })
.returning()
)
);
if (userIds.length > 0) {
await trx
.insert(userSiteResources)
.values(userIds.map((userId) => ({ userId, siteResourceId })));
}
await rebuildSiteClientAssociations(siteResource, trx);
await rebuildClientAssociations(siteResource, trx);
});
return response(res, {

View File

@@ -1,16 +1,28 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts, sites } from "@server/db";
import {
clientSiteResources,
db,
newts,
roles,
roleSiteResources,
sites,
userSiteResources
} from "@server/db";
import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and } from "drizzle-orm";
import { eq, and, ne } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { updateTargets } from "@server/routers/client/targets";
import { generateSubnetProxyTargets } from "@server/lib/ip";
import { generateSingleSubnetProxyTargets } from "@server/lib/ip";
import {
getClientSiteResourceAccess,
rebuildClientAssociations
} from "@server/lib/rebuildClientAssociations";
const updateSiteResourceParamsSchema = z.strictObject({
siteResourceId: z.string().transform(Number).pipe(z.int().positive()),
@@ -23,12 +35,15 @@ const updateSiteResourceSchema = z
name: z.string().min(1).max(255).optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(),
protocol: z.enum(["tcp", "udp"]).nullish(),
// protocol: z.enum(["tcp", "udp"]).nullish(),
// proxyPort: z.int().positive().nullish(),
// destinationPort: z.int().positive().nullish(),
destination: z.string().min(1).optional(),
enabled: z.boolean().optional(),
alias: z.string().nullish()
alias: z.string().nullish(),
userIds: z.array(z.string()),
roleIds: z.array(z.int()),
clientIds: z.array(z.int())
})
.strict();
@@ -82,7 +97,16 @@ export async function updateSiteResource(
}
const { siteResourceId, siteId, orgId } = parsedParams.data;
const updateData = parsedBody.data;
const {
name,
mode,
destination,
alias,
enabled,
userIds,
roleIds,
clientIds
} = parsedBody.data;
const [site] = await db
.select()
@@ -113,85 +137,131 @@ export async function updateSiteResource(
);
}
// Determine the final mode and validate port mode requirements
const finalMode = updateData.mode || existingSiteResource.mode;
const finalProtocol = updateData.protocol !== undefined ? updateData.protocol : existingSiteResource.protocol;
// const finalProxyPort = updateData.proxyPort !== undefined ? updateData.proxyPort : existingSiteResource.proxyPort;
// const finalDestinationPort = updateData.destinationPort !== undefined ? updateData.destinationPort : existingSiteResource.destinationPort;
// Prepare update data
const updateValues: any = {};
if (updateData.name !== undefined) updateValues.name = updateData.name;
if (updateData.mode !== undefined) updateValues.mode = updateData.mode;
if (updateData.destination !== undefined)
updateValues.destination = updateData.destination;
if (updateData.enabled !== undefined)
updateValues.enabled = updateData.enabled;
// Handle nullish fields (can be undefined, null, or a value)
if (updateData.alias !== undefined) {
updateValues.alias =
updateData.alias && updateData.alias.trim()
? updateData.alias
: null;
}
// Handle port mode fields - include in update if explicitly provided (null or value) or if mode changed
// const isModeChangingFromPort =
// existingSiteResource.mode === "port" &&
// updateData.mode &&
// updateData.mode !== "port";
// if (updateData.protocol !== undefined || isModeChangingFromPort) {
// updateValues.protocol = finalMode === "port" ? finalProtocol : null;
// }
// if (updateData.proxyPort !== undefined || isModeChangingFromPort) {
// updateValues.proxyPort =
// finalMode === "port" ? finalProxyPort : null;
// }
// if (
// updateData.destinationPort !== undefined ||
// isModeChangingFromPort
// ) {
// updateValues.destinationPort =
// finalMode === "port" ? finalDestinationPort : null;
// }
// Update the site resource
const [updatedSiteResource] = await db
.update(siteResources)
.set(updateValues)
.where(
and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => {
// Update the site resource
[updatedSiteResource] = await trx
.update(siteResources)
.set({
name: name,
mode: mode,
destination: destination,
enabled: enabled,
alias: alias && alias.trim() ? alias : null
})
.where(
and(
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
)
)
.returning();
.returning();
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
//////////////////// update the associations ////////////////////
if (!newt) {
return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found"));
}
await trx
.delete(clientSiteResources)
.where(eq(clientSiteResources.siteResourceId, siteResourceId));
const oldTargets = await generateSubnetProxyTargets([existingSiteResource]);
const newTargets = await generateSubnetProxyTargets([updatedSiteResource]);
if (clientIds.length > 0) {
await trx.insert(clientSiteResources).values(
clientIds.map((clientId) => ({
clientId,
siteResourceId
}))
);
}
await updateTargets(newt.newtId, {
oldTargets: oldTargets,
newTargets: newTargets
await trx
.delete(userSiteResources)
.where(eq(userSiteResources.siteResourceId, siteResourceId));
if (userIds.length > 0) {
await trx
.insert(userSiteResources)
.values(
userIds.map((userId) => ({ userId, siteResourceId }))
);
}
// Get all admin role IDs for this org to exclude from deletion
const adminRoles = await trx
.select()
.from(roles)
.where(
and(
eq(roles.isAdmin, true),
eq(roles.orgId, updatedSiteResource.orgId)
)
);
const adminRoleIds = adminRoles.map((role) => role.roleId);
if (adminRoleIds.length > 0) {
await trx.delete(roleSiteResources).where(
and(
eq(roleSiteResources.siteResourceId, siteResourceId),
ne(roleSiteResources.roleId, adminRoleIds[0]) // delete all but the admin role
)
);
} else {
await trx
.delete(roleSiteResources)
.where(
eq(roleSiteResources.siteResourceId, siteResourceId)
);
}
if (roleIds.length > 0) {
await trx
.insert(roleSiteResources)
.values(
roleIds.map((roleId) => ({ roleId, siteResourceId }))
);
}
const { mergedAllClients } = await rebuildClientAssociations(
updatedSiteResource,
trx
); // we need to call this because we added to the admin role
// after everything is rebuilt above we still need to update the targets if the destination changed
if (
existingSiteResource.destination !==
updatedSiteResource.destination
) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
const oldTargets = generateSingleSubnetProxyTargets(
existingSiteResource,
mergedAllClients
);
const newTargets = generateSingleSubnetProxyTargets(
updatedSiteResource,
mergedAllClients
);
await updateTargets(newt.newtId, {
oldTargets: oldTargets,
newTargets: newTargets
});
}
logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}`
);
});
logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}`
);
return response(res, {
data: updatedSiteResource,
success: true,