Compare commits

...

6 Commits

Author SHA1 Message Date
Owen
2093bb5357 Remove siteSiteResources 2026-03-19 21:44:59 -07:00
Owen
6f2e37948c Its many to one now 2026-03-19 21:30:00 -07:00
Owen
b7421e47cc Switch to using networks 2026-03-19 21:22:04 -07:00
Owen
7cbe3d42a1 Working on refactoring 2026-03-19 12:10:04 -07:00
Owen
d8b511b198 Adjust create and update to be many to one 2026-03-18 20:54:49 -07:00
Owen
102a235407 Adjust schema for many to one site resources 2026-03-18 20:54:38 -07:00
8 changed files with 448 additions and 219 deletions

View File

@@ -216,12 +216,18 @@ export const exitNodes = pgTable("exitNodes", {
export const siteResources = pgTable("siteResources", { export const siteResources = pgTable("siteResources", {
// this is for the clients // this is for the clients
siteResourceId: serial("siteResourceId").primaryKey(), siteResourceId: serial("siteResourceId").primaryKey(),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: varchar("orgId") orgId: varchar("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
defaultNetworkId: integer("defaultNetworkId").references(
() => networks.networkId,
{
onDelete: "restrict"
}
),
niceId: varchar("niceId").notNull(), niceId: varchar("niceId").notNull(),
name: varchar("name").notNull(), name: varchar("name").notNull(),
mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
@@ -241,6 +247,32 @@ export const siteResources = pgTable("siteResources", {
.default("site") .default("site")
}); });
export const networks = pgTable("networks", {
networkId: serial("networkId").primaryKey(),
niceId: text("niceId"),
name: text("name"),
scope: varchar("scope")
.$type<"global" | "resource">()
.notNull()
.default("global"),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull()
});
export const siteNetworks = pgTable("siteNetworks", {
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, {
onDelete: "cascade"
}),
networkId: integer("networkId")
.notNull()
.references(() => networks.networkId, { onDelete: "cascade" })
});
export const clientSiteResources = pgTable("clientSiteResources", { export const clientSiteResources = pgTable("clientSiteResources", {
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
@@ -1074,3 +1106,4 @@ export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type RoundTripMessageTracker = InferSelectModel< export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker typeof roundTripMessageTracker
>; >;
export type Network = InferSelectModel<typeof networks>;

View File

@@ -82,6 +82,9 @@ export const sites = sqliteTable("sites", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null" onDelete: "set null"
}), }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
name: text("name").notNull(), name: text("name").notNull(),
pubKey: text("pubKey"), pubKey: text("pubKey"),
subnet: text("subnet"), subnet: text("subnet"),
@@ -239,12 +242,16 @@ export const siteResources = sqliteTable("siteResources", {
siteResourceId: integer("siteResourceId").primaryKey({ siteResourceId: integer("siteResourceId").primaryKey({
autoIncrement: true autoIncrement: true
}), }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: text("orgId") orgId: text("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
defaultNetworkId: integer("defaultNetworkId").references(
() => networks.networkId,
{ onDelete: "restrict" }
),
niceId: text("niceId").notNull(), niceId: text("niceId").notNull(),
name: text("name").notNull(), name: text("name").notNull(),
mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
@@ -266,6 +273,30 @@ export const siteResources = sqliteTable("siteResources", {
.default("site") .default("site")
}); });
export const networks = sqliteTable("networks", {
networkId: integer("networkId").primaryKey({ autoIncrement: true }),
niceId: text("niceId"),
name: text("name"),
scope: text("scope")
.$type<"global" | "resource">()
.notNull()
.default("global"),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const siteNetworks = sqliteTable("siteNetworks", {
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, {
onDelete: "cascade"
}),
networkId: integer("networkId")
.notNull()
.references(() => networks.networkId, { onDelete: "cascade" })
});
export const clientSiteResources = sqliteTable("clientSiteResources", { export const clientSiteResources = sqliteTable("clientSiteResources", {
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
@@ -1158,6 +1189,7 @@ export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>; export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>; export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type SiteResource = InferSelectModel<typeof siteResources>; export type SiteResource = InferSelectModel<typeof siteResources>;
export type Network = InferSelectModel<typeof networks>;
export type OrgDomains = InferSelectModel<typeof orgDomains>; export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SetupToken = InferSelectModel<typeof setupTokens>; export type SetupToken = InferSelectModel<typeof setupTokens>;
export type HostMeta = InferSelectModel<typeof hostMeta>; export type HostMeta = InferSelectModel<typeof hostMeta>;

View File

@@ -121,8 +121,8 @@ export async function applyBlueprint({
for (const result of clientResourcesResults) { for (const result of clientResourcesResults) {
if ( if (
result.oldSiteResource && result.oldSiteResource &&
result.oldSiteResource.siteId != JSON.stringify(result.newSites?.sort()) !==
result.newSiteResource.siteId JSON.stringify(result.oldSites?.sort())
) { ) {
// query existing associations // query existing associations
const existingRoleIds = await trx const existingRoleIds = await trx
@@ -222,13 +222,15 @@ export async function applyBlueprint({
trx trx
); );
} else { } else {
const [newSite] = await trx let good = true;
for (const newSite of result.newSites) {
const [site] = await trx
.select() .select()
.from(sites) .from(sites)
.innerJoin(newts, eq(sites.siteId, newts.siteId)) .innerJoin(newts, eq(sites.siteId, newts.siteId))
.where( .where(
and( and(
eq(sites.siteId, result.newSiteResource.siteId), eq(sites.siteId, newSite.siteId),
eq(sites.orgId, orgId), eq(sites.orgId, orgId),
eq(sites.type, "newt"), eq(sites.type, "newt"),
isNotNull(sites.pubKey) isNotNull(sites.pubKey)
@@ -236,24 +238,30 @@ export async function applyBlueprint({
) )
.limit(1); .limit(1);
if (!newSite) { if (!site) {
logger.debug( logger.debug(
`No newt site found for client resource ${result.newSiteResource.siteResourceId}, skipping target update` `No newt sites found for client resource ${result.newSiteResource.siteResourceId}, skipping target update`
); );
continue; good = false;
break;
} }
logger.debug( logger.debug(
`Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.sites.siteId}` `Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.siteId}`
); );
}
if (!good) {
continue;
}
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
result.oldSiteResource, result.oldSiteResource,
result.newSiteResource, result.newSiteResource,
{ result.newSites.map((site) => ({
siteId: newSite.sites.siteId, siteId: site.siteId,
orgId: newSite.sites.orgId orgId: result.newSiteResource.orgId
}, })),
trx trx
); );
} }

View File

@@ -3,12 +3,15 @@ import {
clientSiteResources, clientSiteResources,
roles, roles,
roleSiteResources, roleSiteResources,
Site,
SiteResource, SiteResource,
siteNetworks,
siteResources, siteResources,
Transaction, Transaction,
userOrgs, userOrgs,
users, users,
userSiteResources userSiteResources,
networks
} from "@server/db"; } from "@server/db";
import { sites } from "@server/db"; import { sites } from "@server/db";
import { eq, and, ne, inArray, or } from "drizzle-orm"; import { eq, and, ne, inArray, or } from "drizzle-orm";
@@ -19,6 +22,8 @@ import { getNextAvailableAliasAddress } from "../ip";
export type ClientResourcesResults = { export type ClientResourcesResults = {
newSiteResource: SiteResource; newSiteResource: SiteResource;
oldSiteResource?: SiteResource; oldSiteResource?: SiteResource;
newSites: { siteId: number }[];
oldSites: { siteId: number }[];
}[]; }[];
export async function updateClientResources( export async function updateClientResources(
@@ -43,12 +48,21 @@ export async function updateClientResources(
) )
.limit(1); .limit(1);
const existingSiteIds = existingResource?.networkId
? await trx
.select({ siteId: sites.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, existingResource.networkId))
: [];
let allSites: { siteId: number }[] = [];
if (resourceData.site) {
let siteSingle;
const resourceSiteId = resourceData.site; const resourceSiteId = resourceData.site;
let site;
if (resourceSiteId) { if (resourceSiteId) {
// Look up site by niceId // Look up site by niceId
[site] = await trx [siteSingle] = await trx
.select({ siteId: sites.siteId }) .select({ siteId: sites.siteId })
.from(sites) .from(sites)
.where( .where(
@@ -60,20 +74,45 @@ export async function updateClientResources(
.limit(1); .limit(1);
} else if (siteId) { } else if (siteId) {
// Use the provided siteId directly, but verify it belongs to the org // Use the provided siteId directly, but verify it belongs to the org
[site] = await trx [siteSingle] = await trx
.select({ siteId: sites.siteId }) .select({ siteId: sites.siteId })
.from(sites) .from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) .where(
and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))
)
.limit(1); .limit(1);
} else { } else {
throw new Error(`Target site is required`); throw new Error(`Target site is required`);
} }
if (!site) { if (!siteSingle) {
throw new Error( throw new Error(
`Site not found: ${resourceSiteId} in org ${orgId}` `Site not found: ${resourceSiteId} in org ${orgId}`
); );
} }
allSites.push(siteSingle);
}
if (resourceData.sites) {
for (const siteNiceId of resourceData.sites) {
const [site] = await trx
.select({ siteId: sites.siteId })
.from(sites)
.where(
and(
eq(sites.niceId, siteNiceId),
eq(sites.orgId, orgId)
)
)
.limit(1);
if (!site) {
throw new Error(
`Site not found: ${siteId} in org ${orgId}`
);
}
allSites.push(site);
}
}
if (existingResource) { if (existingResource) {
// Update existing resource // Update existing resource
@@ -81,7 +120,6 @@ export async function updateClientResources(
.update(siteResources) .update(siteResources)
.set({ .set({
name: resourceData.name || resourceNiceId, name: resourceData.name || resourceNiceId,
siteId: site.siteId,
mode: resourceData.mode, mode: resourceData.mode,
destination: resourceData.destination, destination: resourceData.destination,
enabled: true, // hardcoded for now enabled: true, // hardcoded for now
@@ -102,6 +140,21 @@ export async function updateClientResources(
const siteResourceId = existingResource.siteResourceId; const siteResourceId = existingResource.siteResourceId;
const orgId = existingResource.orgId; const orgId = existingResource.orgId;
if (updatedResource.networkId) {
await trx
.delete(siteNetworks)
.where(
eq(siteNetworks.networkId, updatedResource.networkId)
);
for (const site of allSites) {
await trx.insert(siteNetworks).values({
siteId: site.siteId,
networkId: updatedResource.networkId
});
}
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where(eq(clientSiteResources.siteResourceId, siteResourceId)); .where(eq(clientSiteResources.siteResourceId, siteResourceId));
@@ -204,7 +257,9 @@ export async function updateClientResources(
results.push({ results.push({
newSiteResource: updatedResource, newSiteResource: updatedResource,
oldSiteResource: existingResource oldSiteResource: existingResource,
newSites: allSites,
oldSites: existingSiteIds
}); });
} else { } else {
let aliasAddress: string | null = null; let aliasAddress: string | null = null;
@@ -213,13 +268,21 @@ export async function updateClientResources(
aliasAddress = await getNextAvailableAliasAddress(orgId); aliasAddress = await getNextAvailableAliasAddress(orgId);
} }
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
// Create new resource // Create new resource
const [newResource] = await trx const [newResource] = await trx
.insert(siteResources) .insert(siteResources)
.values({ .values({
orgId: orgId, orgId: orgId,
siteId: site.siteId,
niceId: resourceNiceId, niceId: resourceNiceId,
networkId: network.networkId,
name: resourceData.name || resourceNiceId, name: resourceData.name || resourceNiceId,
mode: resourceData.mode, mode: resourceData.mode,
destination: resourceData.destination, destination: resourceData.destination,
@@ -235,6 +298,13 @@ export async function updateClientResources(
const siteResourceId = newResource.siteResourceId; const siteResourceId = newResource.siteResourceId;
for (const site of allSites) {
await trx.insert(siteNetworks).values({
siteId: site.siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -324,7 +394,11 @@ export async function updateClientResources(
`Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}` `Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}`
); );
results.push({ newSiteResource: newResource }); results.push({
newSiteResource: newResource,
newSites: allSites,
oldSites: existingSiteIds
});
} }
} }

View File

@@ -312,7 +312,8 @@ export const ClientResourceSchema = z
.object({ .object({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr"]), mode: z.enum(["host", "cidr"]),
site: z.string(), site: z.string(), // DEPRECATED IN FAVOR OF sites
sites: z.array(z.string()).optional().default([]),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),

View File

@@ -5,6 +5,8 @@ import {
orgs, orgs,
roles, roles,
roleSiteResources, roleSiteResources,
siteNetworks,
networks,
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
@@ -23,7 +25,7 @@ import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -37,7 +39,7 @@ const createSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]), mode: z.enum(["host", "cidr", "port"]),
siteId: z.int(), siteIds: z.array(z.int()),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),
@@ -159,7 +161,7 @@ export async function createSiteResource(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { const {
name, name,
siteId, siteIds,
mode, mode,
// protocol, // protocol,
// proxyPort, // proxyPort,
@@ -178,14 +180,16 @@ export async function createSiteResource(
} = parsedBody.data; } = parsedBody.data;
// Verify the site exists and belongs to the org // Verify the site exists and belongs to the org
const [site] = await db const sitesToAssign = await db
.select() .select()
.from(sites) .from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) .where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId)))
.limit(1); .limit(1);
if (!site) { if (sitesToAssign.length !== siteIds.length) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
} }
const [org] = await db const [org] = await db
@@ -287,12 +291,29 @@ export async function createSiteResource(
let newSiteResource: SiteResource | undefined; let newSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
if (!network) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to create network`
)
);
}
// Create the site resource // Create the site resource
const insertValues: typeof siteResources.$inferInsert = { const insertValues: typeof siteResources.$inferInsert = {
siteId,
niceId, niceId,
orgId, orgId,
name, name,
networkId: network.networkId,
mode: mode as "host" | "cidr", mode: mode as "host" | "cidr",
destination, destination,
enabled, enabled,
@@ -317,6 +338,13 @@ export async function createSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -359,17 +387,22 @@ export async function createSiteResource(
); );
} }
for (const siteToAssign of sitesToAssign) {
const [newt] = await trx const [newt] = await trx
.select() .select()
.from(newts) .from(newts)
.where(eq(newts.siteId, site.siteId)) .where(eq(newts.siteId, siteToAssign.siteId))
.limit(1); .limit(1);
if (!newt) { if (!newt) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") createHttpError(
HttpCode.NOT_FOUND,
`Newt not found for site ${siteToAssign.siteId}`
)
); );
} }
}
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
newSiteResource, newSiteResource,
@@ -387,7 +420,7 @@ export async function createSiteResource(
} }
logger.info( logger.info(
`Created site resource ${newSiteResource.siteResourceId} for site ${siteId}` `Created site resource ${newSiteResource.siteResourceId} for org ${orgId}`
); );
return response(res, { return response(res, {

View File

@@ -1,4 +1,4 @@
import { db, SiteResource, siteResources, sites } from "@server/db"; import { db, SiteResource, siteNetworks, siteResources, sites } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -73,9 +73,9 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{ export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
siteResources: (SiteResource & { siteResources: (SiteResource & {
siteName: string; siteNames: string[];
siteNiceId: string; siteNiceIds: string[];
siteAddress: string | null; siteAddresses: (string | null)[];
})[]; })[];
}>; }>;
@@ -83,7 +83,6 @@ function querySiteResourcesBase() {
return db return db
.select({ .select({
siteResourceId: siteResources.siteResourceId, siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId, orgId: siteResources.orgId,
niceId: siteResources.niceId, niceId: siteResources.niceId,
name: siteResources.name, name: siteResources.name,
@@ -100,14 +99,19 @@ function querySiteResourcesBase() {
disableIcmp: siteResources.disableIcmp, disableIcmp: siteResources.disableIcmp,
authDaemonMode: siteResources.authDaemonMode, authDaemonMode: siteResources.authDaemonMode,
authDaemonPort: siteResources.authDaemonPort, authDaemonPort: siteResources.authDaemonPort,
siteName: sites.name, networkId: siteResources.networkId,
siteNiceId: sites.niceId, defaultNetworkId: siteResources.defaultNetworkId,
siteAddress: sites.address siteNames: sql<string[]>`array_agg(${sites.name})`,
siteNiceIds: sql<string[]>`array_agg(${sites.niceId})`,
siteAddresses: sql<(string | null)[]>`array_agg(${sites.address})`
}) })
.from(siteResources) .from(siteResources)
.innerJoin(sites, eq(siteResources.siteId, sites.siteId)); .innerJoin(siteNetworks, eq(siteResources.networkId, siteNetworks.networkId))
.innerJoin(sites, eq(siteNetworks.siteId, sites.siteId))
.groupBy(siteResources.siteResourceId);
} }
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
path: "/org/{orgId}/site-resources", path: "/org/{orgId}/site-resources",

View File

@@ -8,7 +8,9 @@ import {
orgs, orgs,
roles, roles,
roleSiteResources, roleSiteResources,
siteNetworks,
sites, sites,
networks,
Transaction, Transaction,
userSiteResources userSiteResources
} from "@server/db"; } from "@server/db";
@@ -16,7 +18,7 @@ import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { eq, and, ne } from "drizzle-orm"; import { eq, and, ne, inArray } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -42,7 +44,7 @@ const updateSiteResourceParamsSchema = z.strictObject({
const updateSiteResourceSchema = z const updateSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
siteId: z.int(), siteIds: z.array(z.int()),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(), // niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(), // mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(), mode: z.enum(["host", "cidr"]).optional(),
@@ -166,7 +168,7 @@ export async function updateSiteResource(
const { siteResourceId } = parsedParams.data; const { siteResourceId } = parsedParams.data;
const { const {
name, name,
siteId, // because it can change siteIds, // because it can change
mode, mode,
destination, destination,
alias, alias,
@@ -181,16 +183,6 @@ export async function updateSiteResource(
authDaemonMode authDaemonMode
} = parsedBody.data; } = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists // Check if site resource exists
const [existingSiteResource] = await db const [existingSiteResource] = await db
.select() .select()
@@ -230,6 +222,24 @@ export async function updateSiteResource(
); );
} }
// Verify the site exists and belongs to the org
const sitesToAssign = await db
.select()
.from(sites)
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.orgId, existingSiteResource.orgId)
)
)
.limit(1);
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
// Only check if destination is an IP address // Only check if destination is an IP address
const isIp = z const isIp = z
.union([z.ipv4(), z.ipv6()]) .union([z.ipv4(), z.ipv6()])
@@ -247,25 +257,24 @@ export async function updateSiteResource(
); );
} }
let existingSite = site; let sitesChanged = false;
let siteChanged = false; const existingSiteIds = existingSiteResource.networkId
if (existingSiteResource.siteId !== siteId) { ? await db
siteChanged = true;
// get the existing site
[existingSite] = await db
.select() .select()
.from(sites) .from(siteNetworks)
.where(eq(sites.siteId, existingSiteResource.siteId)) .where(
.limit(1); eq(siteNetworks.networkId, existingSiteResource.networkId)
if (!existingSite) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Existing site not found"
) )
); : [];
}
const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId));
const newSiteIdSet = new Set(siteIds);
if (
existingSiteIdSet.size !== newSiteIdSet.size ||
![...existingSiteIdSet].every((id) => newSiteIdSet.has(id))
) {
sitesChanged = true;
} }
// make sure the alias is unique within the org if provided // make sure the alias is unique within the org if provided
@@ -295,7 +304,7 @@ export async function updateSiteResource(
let updatedSiteResource: SiteResource | undefined; let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place // if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place
if (siteChanged) { if (sitesChanged) {
// delete the existing site resource // delete the existing site resource
await trx await trx
.delete(siteResources) .delete(siteResources)
@@ -321,7 +330,8 @@ export async function updateSiteResource(
const sshPamSet = const sshPamSet =
isLicensedSshPam && isLicensedSshPam &&
(authDaemonPort !== undefined || authDaemonMode !== undefined) (authDaemonPort !== undefined ||
authDaemonMode !== undefined)
? { ? {
...(authDaemonPort !== undefined && { ...(authDaemonPort !== undefined && {
authDaemonPort authDaemonPort
@@ -335,7 +345,6 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -423,7 +432,8 @@ export async function updateSiteResource(
// Update the site resource // Update the site resource
const sshPamSet = const sshPamSet =
isLicensedSshPam && isLicensedSshPam &&
(authDaemonPort !== undefined || authDaemonMode !== undefined) (authDaemonPort !== undefined ||
authDaemonMode !== undefined)
? { ? {
...(authDaemonPort !== undefined && { ...(authDaemonPort !== undefined && {
authDaemonPort authDaemonPort
@@ -437,7 +447,6 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -454,6 +463,22 @@ export async function updateSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(siteNetworks.networkId, updatedSiteResource.networkId!)
// TODO: HERE WE FORCE THE NETWORK TO BE DEFINED BUT THE NETWORK CAN GET DELETED and we need to handle that
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
// TODO: HERE WE FORCE THE NETWORK TO BE DEFINED BUT THE NETWORK CAN GET DELETED and we need to handle that
});
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where( .where(
@@ -524,13 +549,16 @@ export async function updateSiteResource(
} }
logger.info( logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}` `Updated site resource ${siteResourceId}`
); );
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
existingSiteResource, existingSiteResource,
updatedSiteResource, updatedSiteResource,
{ siteId: site.siteId, orgId: site.orgId }, siteIds.map((siteId) => ({
siteId,
orgId: existingSiteResource.orgId
})),
trx trx
); );
} }
@@ -557,7 +585,7 @@ export async function updateSiteResource(
export async function handleMessagingForUpdatedSiteResource( export async function handleMessagingForUpdatedSiteResource(
existingSiteResource: SiteResource | undefined, existingSiteResource: SiteResource | undefined,
updatedSiteResource: SiteResource, updatedSiteResource: SiteResource,
site: { siteId: number; orgId: string }, sites: { siteId: number; orgId: string }[],
trx: Transaction trx: Transaction
) { ) {
logger.debug( logger.debug(
@@ -594,6 +622,7 @@ export async function handleMessagingForUpdatedSiteResource(
// if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all // if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all
if (destinationChanged || aliasChanged || portRangesChanged) { if (destinationChanged || aliasChanged || portRangesChanged) {
for (const site of sites) {
const [newt] = await trx const [newt] = await trx
.select() .select()
.from(newts) .from(newts)
@@ -617,10 +646,14 @@ export async function handleMessagingForUpdatedSiteResource(
mergedAllClients mergedAllClients
); );
await updateTargets(newt.newtId, { await updateTargets(
newt.newtId,
{
oldTargets: oldTargets, oldTargets: oldTargets,
newTargets: newTargets newTargets: newTargets
}, newt.version); },
newt.version
);
} }
const olmJobs: Promise<void>[] = []; const olmJobs: Promise<void>[] = [];
@@ -637,13 +670,21 @@ export async function handleMessagingForUpdatedSiteResource(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(
siteNetworks.networkId,
siteResources.networkId
// TODO: HERE WE FORCE THE NETWORK TO BE DEFINED BUT THE NETWORK CAN GET DELETED and we need to handle that
)
)
.where( .where(
and( and(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
), ),
eq(siteResources.siteId, site.siteId), eq(siteNetworks.siteId, site.siteId),
eq( eq(
siteResources.destination, siteResources.destination,
existingSiteResource.destination existingSiteResource.destination
@@ -655,6 +696,7 @@ export async function handleMessagingForUpdatedSiteResource(
) )
); );
const oldDestinationStillInUseByASite = const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0; oldDestinationStillInUseSites.length > 0;
@@ -662,10 +704,11 @@ export async function handleMessagingForUpdatedSiteResource(
olmJobs.push( olmJobs.push(
updatePeerData( updatePeerData(
client.clientId, client.clientId,
updatedSiteResource.siteId, site.siteId,
destinationChanged destinationChanged
? { ? {
oldRemoteSubnets: !oldDestinationStillInUseByASite oldRemoteSubnets:
!oldDestinationStillInUseByASite
? generateRemoteSubnets([ ? generateRemoteSubnets([
existingSiteResource existingSiteResource
]) ])
@@ -691,4 +734,5 @@ export async function handleMessagingForUpdatedSiteResource(
await Promise.all(olmJobs); await Promise.all(olmJobs);
} }
}
} }