Compare commits

...

14 Commits

Author SHA1 Message Date
Owen
02033f611f First pass at HA 2026-03-23 11:44:02 -07:00
Owen
1366901e24 Adjust build functions 2026-03-22 14:40:57 -07:00
Owen
c4f48f5748 WIP - more conversion 2026-03-22 14:29:47 -07:00
Owen
c48bc71443 Update crud endpoints and ui 2026-03-22 14:18:34 -07:00
Owen
d85496453f Change SSH WIP 2026-03-21 10:40:12 -07:00
Owen
21b91374a3 Merge branch 'private-site-ha' of github.com:fosrl/pangolin into private-site-ha 2026-03-20 17:24:27 -07:00
Owen
a1ce7f54a0 Continue to rebase 2026-03-20 09:17:10 -07:00
Owen
87524fe8ae Remove siteSiteResources 2026-03-19 21:53:52 -07:00
Owen
2093bb5357 Remove siteSiteResources 2026-03-19 21:44:59 -07:00
Owen
6f2e37948c Its many to one now 2026-03-19 21:30:00 -07:00
Owen
b7421e47cc Switch to using networks 2026-03-19 21:22:04 -07:00
Owen
7cbe3d42a1 Working on refactoring 2026-03-19 12:10:04 -07:00
Owen
d8b511b198 Adjust create and update to be many to one 2026-03-18 20:54:49 -07:00
Owen
102a235407 Adjust schema for many to one site resources 2026-03-18 20:54:38 -07:00
20 changed files with 1142 additions and 708 deletions

View File

@@ -216,12 +216,18 @@ export const exitNodes = pgTable("exitNodes", {
export const siteResources = pgTable("siteResources", { export const siteResources = pgTable("siteResources", {
// this is for the clients // this is for the clients
siteResourceId: serial("siteResourceId").primaryKey(), siteResourceId: serial("siteResourceId").primaryKey(),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: varchar("orgId") orgId: varchar("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
defaultNetworkId: integer("defaultNetworkId").references(
() => networks.networkId,
{
onDelete: "restrict"
}
),
niceId: varchar("niceId").notNull(), niceId: varchar("niceId").notNull(),
name: varchar("name").notNull(), name: varchar("name").notNull(),
mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" mode: varchar("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
@@ -241,6 +247,32 @@ export const siteResources = pgTable("siteResources", {
.default("site") .default("site")
}); });
export const networks = pgTable("networks", {
networkId: serial("networkId").primaryKey(),
niceId: text("niceId"),
name: text("name"),
scope: varchar("scope")
.$type<"global" | "resource">()
.notNull()
.default("global"),
orgId: varchar("orgId")
.references(() => orgs.orgId, {
onDelete: "cascade"
})
.notNull()
});
export const siteNetworks = pgTable("siteNetworks", {
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, {
onDelete: "cascade"
}),
networkId: integer("networkId")
.notNull()
.references(() => networks.networkId, { onDelete: "cascade" })
});
export const clientSiteResources = pgTable("clientSiteResources", { export const clientSiteResources = pgTable("clientSiteResources", {
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
@@ -1074,3 +1106,4 @@ export type RequestAuditLog = InferSelectModel<typeof requestAuditLog>;
export type RoundTripMessageTracker = InferSelectModel< export type RoundTripMessageTracker = InferSelectModel<
typeof roundTripMessageTracker typeof roundTripMessageTracker
>; >;
export type Network = InferSelectModel<typeof networks>;

View File

@@ -82,6 +82,9 @@ export const sites = sqliteTable("sites", {
exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, { exitNodeId: integer("exitNode").references(() => exitNodes.exitNodeId, {
onDelete: "set null" onDelete: "set null"
}), }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
name: text("name").notNull(), name: text("name").notNull(),
pubKey: text("pubKey"), pubKey: text("pubKey"),
subnet: text("subnet"), subnet: text("subnet"),
@@ -239,12 +242,16 @@ export const siteResources = sqliteTable("siteResources", {
siteResourceId: integer("siteResourceId").primaryKey({ siteResourceId: integer("siteResourceId").primaryKey({
autoIncrement: true autoIncrement: true
}), }),
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, { onDelete: "cascade" }),
orgId: text("orgId") orgId: text("orgId")
.notNull() .notNull()
.references(() => orgs.orgId, { onDelete: "cascade" }), .references(() => orgs.orgId, { onDelete: "cascade" }),
networkId: integer("networkId").references(() => networks.networkId, {
onDelete: "set null"
}),
defaultNetworkId: integer("defaultNetworkId").references(
() => networks.networkId,
{ onDelete: "restrict" }
),
niceId: text("niceId").notNull(), niceId: text("niceId").notNull(),
name: text("name").notNull(), name: text("name").notNull(),
mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port" mode: text("mode").$type<"host" | "cidr">().notNull(), // "host" | "cidr" | "port"
@@ -266,6 +273,30 @@ export const siteResources = sqliteTable("siteResources", {
.default("site") .default("site")
}); });
export const networks = sqliteTable("networks", {
networkId: integer("networkId").primaryKey({ autoIncrement: true }),
niceId: text("niceId"),
name: text("name"),
scope: text("scope")
.$type<"global" | "resource">()
.notNull()
.default("global"),
orgId: text("orgId")
.notNull()
.references(() => orgs.orgId, { onDelete: "cascade" })
});
export const siteNetworks = sqliteTable("siteNetworks", {
siteId: integer("siteId")
.notNull()
.references(() => sites.siteId, {
onDelete: "cascade"
}),
networkId: integer("networkId")
.notNull()
.references(() => networks.networkId, { onDelete: "cascade" })
});
export const clientSiteResources = sqliteTable("clientSiteResources", { export const clientSiteResources = sqliteTable("clientSiteResources", {
clientId: integer("clientId") clientId: integer("clientId")
.notNull() .notNull()
@@ -1158,6 +1189,7 @@ export type ApiKey = InferSelectModel<typeof apiKeys>;
export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>; export type ApiKeyAction = InferSelectModel<typeof apiKeyActions>;
export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>; export type ApiKeyOrg = InferSelectModel<typeof apiKeyOrg>;
export type SiteResource = InferSelectModel<typeof siteResources>; export type SiteResource = InferSelectModel<typeof siteResources>;
export type Network = InferSelectModel<typeof networks>;
export type OrgDomains = InferSelectModel<typeof orgDomains>; export type OrgDomains = InferSelectModel<typeof orgDomains>;
export type SetupToken = InferSelectModel<typeof setupTokens>; export type SetupToken = InferSelectModel<typeof setupTokens>;
export type HostMeta = InferSelectModel<typeof hostMeta>; export type HostMeta = InferSelectModel<typeof hostMeta>;

View File

@@ -121,8 +121,8 @@ export async function applyBlueprint({
for (const result of clientResourcesResults) { for (const result of clientResourcesResults) {
if ( if (
result.oldSiteResource && result.oldSiteResource &&
result.oldSiteResource.siteId != JSON.stringify(result.newSites?.sort()) !==
result.newSiteResource.siteId JSON.stringify(result.oldSites?.sort())
) { ) {
// query existing associations // query existing associations
const existingRoleIds = await trx const existingRoleIds = await trx
@@ -222,13 +222,15 @@ export async function applyBlueprint({
trx trx
); );
} else { } else {
const [newSite] = await trx let good = true;
for (const newSite of result.newSites) {
const [site] = await trx
.select() .select()
.from(sites) .from(sites)
.innerJoin(newts, eq(sites.siteId, newts.siteId)) .innerJoin(newts, eq(sites.siteId, newts.siteId))
.where( .where(
and( and(
eq(sites.siteId, result.newSiteResource.siteId), eq(sites.siteId, newSite.siteId),
eq(sites.orgId, orgId), eq(sites.orgId, orgId),
eq(sites.type, "newt"), eq(sites.type, "newt"),
isNotNull(sites.pubKey) isNotNull(sites.pubKey)
@@ -236,24 +238,30 @@ export async function applyBlueprint({
) )
.limit(1); .limit(1);
if (!newSite) { if (!site) {
logger.debug( logger.debug(
`No newt site found for client resource ${result.newSiteResource.siteResourceId}, skipping target update` `No newt sites found for client resource ${result.newSiteResource.siteResourceId}, skipping target update`
); );
continue; good = false;
break;
} }
logger.debug( logger.debug(
`Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.sites.siteId}` `Updating client resource ${result.newSiteResource.siteResourceId} on site ${newSite.siteId}`
); );
}
if (!good) {
continue;
}
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
result.oldSiteResource, result.oldSiteResource,
result.newSiteResource, result.newSiteResource,
{ result.newSites.map((site) => ({
siteId: newSite.sites.siteId, siteId: site.siteId,
orgId: newSite.sites.orgId orgId: result.newSiteResource.orgId
}, })),
trx trx
); );
} }

View File

@@ -3,12 +3,15 @@ import {
clientSiteResources, clientSiteResources,
roles, roles,
roleSiteResources, roleSiteResources,
Site,
SiteResource, SiteResource,
siteNetworks,
siteResources, siteResources,
Transaction, Transaction,
userOrgs, userOrgs,
users, users,
userSiteResources userSiteResources,
networks
} from "@server/db"; } from "@server/db";
import { sites } from "@server/db"; import { sites } from "@server/db";
import { eq, and, ne, inArray, or } from "drizzle-orm"; import { eq, and, ne, inArray, or } from "drizzle-orm";
@@ -19,6 +22,8 @@ import { getNextAvailableAliasAddress } from "../ip";
export type ClientResourcesResults = { export type ClientResourcesResults = {
newSiteResource: SiteResource; newSiteResource: SiteResource;
oldSiteResource?: SiteResource; oldSiteResource?: SiteResource;
newSites: { siteId: number }[];
oldSites: { siteId: number }[];
}[]; }[];
export async function updateClientResources( export async function updateClientResources(
@@ -43,12 +48,21 @@ export async function updateClientResources(
) )
.limit(1); .limit(1);
const existingSiteIds = existingResource?.networkId
? await trx
.select({ siteId: sites.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, existingResource.networkId))
: [];
let allSites: { siteId: number }[] = [];
if (resourceData.site) {
let siteSingle;
const resourceSiteId = resourceData.site; const resourceSiteId = resourceData.site;
let site;
if (resourceSiteId) { if (resourceSiteId) {
// Look up site by niceId // Look up site by niceId
[site] = await trx [siteSingle] = await trx
.select({ siteId: sites.siteId }) .select({ siteId: sites.siteId })
.from(sites) .from(sites)
.where( .where(
@@ -60,20 +74,45 @@ export async function updateClientResources(
.limit(1); .limit(1);
} else if (siteId) { } else if (siteId) {
// Use the provided siteId directly, but verify it belongs to the org // Use the provided siteId directly, but verify it belongs to the org
[site] = await trx [siteSingle] = await trx
.select({ siteId: sites.siteId }) .select({ siteId: sites.siteId })
.from(sites) .from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) .where(
and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))
)
.limit(1); .limit(1);
} else { } else {
throw new Error(`Target site is required`); throw new Error(`Target site is required`);
} }
if (!site) { if (!siteSingle) {
throw new Error( throw new Error(
`Site not found: ${resourceSiteId} in org ${orgId}` `Site not found: ${resourceSiteId} in org ${orgId}`
); );
} }
allSites.push(siteSingle);
}
if (resourceData.sites) {
for (const siteNiceId of resourceData.sites) {
const [site] = await trx
.select({ siteId: sites.siteId })
.from(sites)
.where(
and(
eq(sites.niceId, siteNiceId),
eq(sites.orgId, orgId)
)
)
.limit(1);
if (!site) {
throw new Error(
`Site not found: ${siteId} in org ${orgId}`
);
}
allSites.push(site);
}
}
if (existingResource) { if (existingResource) {
// Update existing resource // Update existing resource
@@ -81,7 +120,6 @@ export async function updateClientResources(
.update(siteResources) .update(siteResources)
.set({ .set({
name: resourceData.name || resourceNiceId, name: resourceData.name || resourceNiceId,
siteId: site.siteId,
mode: resourceData.mode, mode: resourceData.mode,
destination: resourceData.destination, destination: resourceData.destination,
enabled: true, // hardcoded for now enabled: true, // hardcoded for now
@@ -102,6 +140,21 @@ export async function updateClientResources(
const siteResourceId = existingResource.siteResourceId; const siteResourceId = existingResource.siteResourceId;
const orgId = existingResource.orgId; const orgId = existingResource.orgId;
if (updatedResource.networkId) {
await trx
.delete(siteNetworks)
.where(
eq(siteNetworks.networkId, updatedResource.networkId)
);
for (const site of allSites) {
await trx.insert(siteNetworks).values({
siteId: site.siteId,
networkId: updatedResource.networkId
});
}
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where(eq(clientSiteResources.siteResourceId, siteResourceId)); .where(eq(clientSiteResources.siteResourceId, siteResourceId));
@@ -204,7 +257,9 @@ export async function updateClientResources(
results.push({ results.push({
newSiteResource: updatedResource, newSiteResource: updatedResource,
oldSiteResource: existingResource oldSiteResource: existingResource,
newSites: allSites,
oldSites: existingSiteIds
}); });
} else { } else {
let aliasAddress: string | null = null; let aliasAddress: string | null = null;
@@ -213,13 +268,22 @@ export async function updateClientResources(
aliasAddress = await getNextAvailableAliasAddress(orgId); aliasAddress = await getNextAvailableAliasAddress(orgId);
} }
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
// Create new resource // Create new resource
const [newResource] = await trx const [newResource] = await trx
.insert(siteResources) .insert(siteResources)
.values({ .values({
orgId: orgId, orgId: orgId,
siteId: site.siteId,
niceId: resourceNiceId, niceId: resourceNiceId,
networkId: network.networkId,
defaultNetworkId: network.networkId,
name: resourceData.name || resourceNiceId, name: resourceData.name || resourceNiceId,
mode: resourceData.mode, mode: resourceData.mode,
destination: resourceData.destination, destination: resourceData.destination,
@@ -235,6 +299,13 @@ export async function updateClientResources(
const siteResourceId = newResource.siteResourceId; const siteResourceId = newResource.siteResourceId;
for (const site of allSites) {
await trx.insert(siteNetworks).values({
siteId: site.siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -324,7 +395,11 @@ export async function updateClientResources(
`Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}` `Created new client resource ${newResource.name} (${newResource.siteResourceId}) for org ${orgId}`
); );
results.push({ newSiteResource: newResource }); results.push({
newSiteResource: newResource,
newSites: allSites,
oldSites: existingSiteIds
});
} }
} }

View File

@@ -312,7 +312,8 @@ export const ClientResourceSchema = z
.object({ .object({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr"]), mode: z.enum(["host", "cidr"]),
site: z.string(), site: z.string(), // DEPRECATED IN FAVOR OF sites
sites: z.array(z.string()).optional().default([]),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),

View File

@@ -11,6 +11,7 @@ import {
roleSiteResources, roleSiteResources,
Site, Site,
SiteResource, SiteResource,
siteNetworks,
siteResources, siteResources,
sites, sites,
Transaction, Transaction,
@@ -47,15 +48,23 @@ export async function getClientSiteResourceAccess(
siteResource: SiteResource, siteResource: SiteResource,
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
) { ) {
// get the site // get all sites associated with this siteResource via its network
const [site] = await trx const sitesList = siteResource.networkId
? await trx
.select() .select()
.from(sites) .from(sites)
.where(eq(sites.siteId, siteResource.siteId)) .innerJoin(
.limit(1); siteNetworks,
eq(siteNetworks.siteId, sites.siteId)
)
.where(eq(siteNetworks.networkId, siteResource.networkId))
.then((rows) => rows.map((row) => row.sites))
: [];
if (!site) { if (sitesList.length === 0) {
throw new Error(`Site with ID ${siteResource.siteId} not found`); logger.warn(
`No sites found for siteResource ${siteResource.siteResourceId} with networkId ${siteResource.networkId}`
);
} }
const roleIds = await trx const roleIds = await trx
@@ -136,7 +145,7 @@ export async function getClientSiteResourceAccess(
const mergedAllClientIds = mergedAllClients.map((c) => c.clientId); const mergedAllClientIds = mergedAllClients.map((c) => c.clientId);
return { return {
site, sitesList,
mergedAllClients, mergedAllClients,
mergedAllClientIds mergedAllClientIds
}; };
@@ -152,17 +161,18 @@ export async function rebuildClientAssociationsFromSiteResource(
subnet: string | null; subnet: string | null;
}[]; }[];
}> { }> {
const siteId = siteResource.siteId; const { sitesList, mergedAllClients, mergedAllClientIds } =
const { site, mergedAllClients, mergedAllClientIds } =
await getClientSiteResourceAccess(siteResource, trx); await getClientSiteResourceAccess(siteResource, trx);
/////////// process the client-siteResource associations /////////// /////////// process the client-siteResource associations ///////////
// get all of the clients associated with other resources on this site // get all of the clients associated with other resources in the same network,
const allUpdatedClientsFromOtherResourcesOnThisSite = await trx // joined through siteNetworks so we know which siteId each client belongs to
const allUpdatedClientsFromOtherResourcesOnThisSite = siteResource.networkId
? await trx
.select({ .select({
clientId: clientSiteResourcesAssociationsCache.clientId clientId: clientSiteResourcesAssociationsCache.clientId,
siteId: siteNetworks.siteId
}) })
.from(clientSiteResourcesAssociationsCache) .from(clientSiteResourcesAssociationsCache)
.innerJoin( .innerJoin(
@@ -172,20 +182,30 @@ export async function rebuildClientAssociationsFromSiteResource(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where( .where(
and( and(
eq(siteResources.siteId, siteId), eq(siteResources.networkId, siteResource.networkId),
ne(siteResources.siteResourceId, siteResource.siteResourceId) ne(
siteResources.siteResourceId,
siteResource.siteResourceId
) )
); )
)
: [];
const allClientIdsFromOtherResourcesOnThisSite = Array.from( // Build a per-site map so the loop below can check by siteId rather than
new Set( // across the entire network.
allUpdatedClientsFromOtherResourcesOnThisSite.map( const clientsFromOtherResourcesBySite = new Map<number, Set<number>>();
(row) => row.clientId for (const row of allUpdatedClientsFromOtherResourcesOnThisSite) {
) if (!clientsFromOtherResourcesBySite.has(row.siteId)) {
) clientsFromOtherResourcesBySite.set(row.siteId, new Set());
); }
clientsFromOtherResourcesBySite.get(row.siteId)!.add(row.clientId);
}
const existingClientSiteResources = await trx const existingClientSiteResources = await trx
.select({ .select({
@@ -259,31 +279,39 @@ export async function rebuildClientAssociationsFromSiteResource(
/////////// process the client-site associations /////////// /////////// process the client-site associations ///////////
for (const site of sitesList) {
const siteId = site.siteId;
const existingClientSites = await trx const existingClientSites = await trx
.select({ .select({
clientId: clientSitesAssociationsCache.clientId clientId: clientSitesAssociationsCache.clientId
}) })
.from(clientSitesAssociationsCache) .from(clientSitesAssociationsCache)
.where(eq(clientSitesAssociationsCache.siteId, siteResource.siteId)); .where(eq(clientSitesAssociationsCache.siteId, siteId));
const existingClientSiteIds = existingClientSites.map( const existingClientSiteIds = existingClientSites.map(
(row) => row.clientId (row) => row.clientId
); );
// Get full client details for existing clients (needed for sending delete messages) // Get full client details for existing clients (needed for sending delete messages)
const existingClients = await trx const existingClients =
existingClientSiteIds.length > 0
? await trx
.select({ .select({
clientId: clients.clientId, clientId: clients.clientId,
pubKey: clients.pubKey, pubKey: clients.pubKey,
subnet: clients.subnet subnet: clients.subnet
}) })
.from(clients) .from(clients)
.where(inArray(clients.clientId, existingClientSiteIds)); .where(inArray(clients.clientId, existingClientSiteIds))
: [];
const otherResourceClientIds = clientsFromOtherResourcesBySite.get(siteId) ?? new Set<number>();
const clientSitesToAdd = mergedAllClientIds.filter( const clientSitesToAdd = mergedAllClientIds.filter(
(clientId) => (clientId) =>
!existingClientSiteIds.includes(clientId) && !existingClientSiteIds.includes(clientId) &&
!allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource !otherResourceClientIds.has(clientId) // dont add if already connected via another site resource
); );
const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({ const clientSitesToInsert = clientSitesToAdd.map((clientId) => ({
@@ -302,7 +330,7 @@ export async function rebuildClientAssociationsFromSiteResource(
const clientSitesToRemove = existingClientSiteIds.filter( const clientSitesToRemove = existingClientSiteIds.filter(
(clientId) => (clientId) =>
!mergedAllClientIds.includes(clientId) && !mergedAllClientIds.includes(clientId) &&
!allClientIdsFromOtherResourcesOnThisSite.includes(clientId) // dont remove if there is still another connection for another site resource !otherResourceClientIds.has(clientId) // dont remove if there is still another connection for another site resource
); );
if (clientSitesToRemove.length > 0) { if (clientSitesToRemove.length > 0) {
@@ -319,8 +347,6 @@ export async function rebuildClientAssociationsFromSiteResource(
); );
} }
/////////// send the messages ///////////
// Now handle the messages to add/remove peers on both the newt and olm sides // Now handle the messages to add/remove peers on both the newt and olm sides
await handleMessagesForSiteClients( await handleMessagesForSiteClients(
site, site,
@@ -331,10 +357,12 @@ export async function rebuildClientAssociationsFromSiteResource(
clientSitesToRemove, clientSitesToRemove,
trx trx
); );
}
// Handle subnet proxy target updates for the resource associations // Handle subnet proxy target updates for the resource associations
await handleSubnetProxyTargetUpdates( await handleSubnetProxyTargetUpdates(
siteResource, siteResource,
sitesList,
mergedAllClients, mergedAllClients,
existingResourceClients, existingResourceClients,
clientSiteResourcesToAdd, clientSiteResourcesToAdd,
@@ -623,6 +651,7 @@ export async function updateClientSiteDestinations(
async function handleSubnetProxyTargetUpdates( async function handleSubnetProxyTargetUpdates(
siteResource: SiteResource, siteResource: SiteResource,
sitesList: Site[],
allClients: { allClients: {
clientId: number; clientId: number;
pubKey: string | null; pubKey: string | null;
@@ -637,22 +666,26 @@ async function handleSubnetProxyTargetUpdates(
clientSiteResourcesToRemove: number[], clientSiteResourcesToRemove: number[],
trx: Transaction | typeof db = db trx: Transaction | typeof db = db
): Promise<void> { ): Promise<void> {
const proxyJobs: Promise<any>[] = [];
const olmJobs: Promise<any>[] = [];
for (const siteData of sitesList) {
const siteId = siteData.siteId;
// Get the newt for this site // Get the newt for this site
const [newt] = await trx const [newt] = await trx
.select() .select()
.from(newts) .from(newts)
.where(eq(newts.siteId, siteResource.siteId)) .where(eq(newts.siteId, siteId))
.limit(1); .limit(1);
if (!newt) { if (!newt) {
logger.warn( logger.warn(
`Newt not found for site ${siteResource.siteId}, skipping subnet proxy target updates` `Newt not found for site ${siteId}, skipping subnet proxy target updates`
); );
return; continue;
} }
const proxyJobs = [];
const olmJobs = [];
// Generate targets for added associations // Generate targets for added associations
if (clientSiteResourcesToAdd.length > 0) { if (clientSiteResourcesToAdd.length > 0) {
const addedClients = allClients.filter((client) => const addedClients = allClients.filter((client) =>
@@ -667,7 +700,7 @@ async function handleSubnetProxyTargetUpdates(
if (targetsToAdd.length > 0) { if (targetsToAdd.length > 0) {
logger.info( logger.info(
`Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}` `Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId} on site ${siteId}`
); );
proxyJobs.push( proxyJobs.push(
addSubnetProxyTargets( addSubnetProxyTargets(
@@ -682,7 +715,7 @@ async function handleSubnetProxyTargetUpdates(
olmJobs.push( olmJobs.push(
addPeerData( addPeerData(
client.clientId, client.clientId,
siteResource.siteId, siteId,
generateRemoteSubnets([siteResource]), generateRemoteSubnets([siteResource]),
generateAliasConfig([siteResource]) generateAliasConfig([siteResource])
) )
@@ -707,7 +740,7 @@ async function handleSubnetProxyTargetUpdates(
if (targetsToRemove.length > 0) { if (targetsToRemove.length > 0) {
logger.info( logger.info(
`Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}` `Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId} on site ${siteId}`
); );
proxyJobs.push( proxyJobs.push(
removeSubnetProxyTargets( removeSubnetProxyTargets(
@@ -719,7 +752,11 @@ async function handleSubnetProxyTargetUpdates(
} }
for (const client of removedClients) { for (const client of removedClients) {
// Check if this client still has access to another resource on this site with the same destination // Check if this client still has access to another resource
// on this specific site with the same destination. We scope
// by siteId (via siteNetworks) rather than networkId because
// removePeerData operates per-site — a resource on a different
// site sharing the same network should not block removal here.
const destinationStillInUse = await trx const destinationStillInUse = await trx
.select() .select()
.from(siteResources) .from(siteResources)
@@ -730,13 +767,17 @@ async function handleSubnetProxyTargetUpdates(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where( .where(
and( and(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
), ),
eq(siteResources.siteId, siteResource.siteId), eq(siteNetworks.siteId, siteId),
eq( eq(
siteResources.destination, siteResources.destination,
siteResource.destination siteResource.destination
@@ -757,7 +798,7 @@ async function handleSubnetProxyTargetUpdates(
olmJobs.push( olmJobs.push(
removePeerData( removePeerData(
client.clientId, client.clientId,
siteResource.siteId, siteId,
remoteSubnetsToRemove, remoteSubnetsToRemove,
generateAliasConfig([siteResource]) generateAliasConfig([siteResource])
) )
@@ -765,6 +806,7 @@ async function handleSubnetProxyTargetUpdates(
} }
} }
} }
}
await Promise.all(proxyJobs); await Promise.all(proxyJobs);
} }
@@ -868,10 +910,25 @@ export async function rebuildClientAssociationsFromClient(
) )
: []; : [];
// Group by siteId for site-level associations // Group by siteId for site-level associations — look up via siteNetworks since
const newSiteIds = Array.from( // siteResources no longer carries a direct siteId column.
new Set(newSiteResources.map((sr) => sr.siteId)) const networkIds = Array.from(
new Set(
newSiteResources
.map((sr) => sr.networkId)
.filter((id): id is number => id !== null)
)
); );
const newSiteIds =
networkIds.length > 0
? await trx
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, networkIds))
.then((rows) =>
Array.from(new Set(rows.map((r) => r.siteId)))
)
: [];
/////////// Process client-siteResource associations /////////// /////////// Process client-siteResource associations ///////////
@@ -1144,13 +1201,45 @@ async function handleMessagesForClientResources(
resourcesToAdd.includes(r.siteResourceId) resourcesToAdd.includes(r.siteResourceId)
); );
// Build (resource, siteId) pairs by looking up siteNetworks for each resource's networkId
const addedNetworkIds = Array.from(
new Set(
addedResources
.map((r) => r.networkId)
.filter((id): id is number => id !== null)
)
);
const addedSiteNetworkRows =
addedNetworkIds.length > 0
? await trx
.select({
networkId: siteNetworks.networkId,
siteId: siteNetworks.siteId
})
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, addedNetworkIds))
: [];
const addedNetworkToSites = new Map<number, number[]>();
for (const row of addedSiteNetworkRows) {
if (!addedNetworkToSites.has(row.networkId)) {
addedNetworkToSites.set(row.networkId, []);
}
addedNetworkToSites.get(row.networkId)!.push(row.siteId);
}
// Group by site for proxy updates // Group by site for proxy updates
const addedBySite = new Map<number, SiteResource[]>(); const addedBySite = new Map<number, SiteResource[]>();
for (const resource of addedResources) { for (const resource of addedResources) {
if (!addedBySite.has(resource.siteId)) { const siteIds =
addedBySite.set(resource.siteId, []); resource.networkId != null
? (addedNetworkToSites.get(resource.networkId) ?? [])
: [];
for (const siteId of siteIds) {
if (!addedBySite.has(siteId)) {
addedBySite.set(siteId, []);
}
addedBySite.get(siteId)!.push(resource);
} }
addedBySite.get(resource.siteId)!.push(resource);
} }
// Add subnet proxy targets for each site // Add subnet proxy targets for each site
@@ -1192,7 +1281,7 @@ async function handleMessagesForClientResources(
olmJobs.push( olmJobs.push(
addPeerData( addPeerData(
client.clientId, client.clientId,
resource.siteId, siteId,
generateRemoteSubnets([resource]), generateRemoteSubnets([resource]),
generateAliasConfig([resource]) generateAliasConfig([resource])
) )
@@ -1204,7 +1293,7 @@ async function handleMessagesForClientResources(
error.message.includes("not found") error.message.includes("not found")
) { ) {
logger.debug( logger.debug(
`Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` `Olm data not found for client ${client.clientId} and site ${siteId}, skipping addition`
); );
} else { } else {
throw error; throw error;
@@ -1221,13 +1310,45 @@ async function handleMessagesForClientResources(
.from(siteResources) .from(siteResources)
.where(inArray(siteResources.siteResourceId, resourcesToRemove)); .where(inArray(siteResources.siteResourceId, resourcesToRemove));
// Build (resource, siteId) pairs via siteNetworks
const removedNetworkIds = Array.from(
new Set(
removedResources
.map((r) => r.networkId)
.filter((id): id is number => id !== null)
)
);
const removedSiteNetworkRows =
removedNetworkIds.length > 0
? await trx
.select({
networkId: siteNetworks.networkId,
siteId: siteNetworks.siteId
})
.from(siteNetworks)
.where(inArray(siteNetworks.networkId, removedNetworkIds))
: [];
const removedNetworkToSites = new Map<number, number[]>();
for (const row of removedSiteNetworkRows) {
if (!removedNetworkToSites.has(row.networkId)) {
removedNetworkToSites.set(row.networkId, []);
}
removedNetworkToSites.get(row.networkId)!.push(row.siteId);
}
// Group by site for proxy updates // Group by site for proxy updates
const removedBySite = new Map<number, SiteResource[]>(); const removedBySite = new Map<number, SiteResource[]>();
for (const resource of removedResources) { for (const resource of removedResources) {
if (!removedBySite.has(resource.siteId)) { const siteIds =
removedBySite.set(resource.siteId, []); resource.networkId != null
? (removedNetworkToSites.get(resource.networkId) ?? [])
: [];
for (const siteId of siteIds) {
if (!removedBySite.has(siteId)) {
removedBySite.set(siteId, []);
}
removedBySite.get(siteId)!.push(resource);
} }
removedBySite.get(resource.siteId)!.push(resource);
} }
// Remove subnet proxy targets for each site // Remove subnet proxy targets for each site
@@ -1265,7 +1386,11 @@ async function handleMessagesForClientResources(
} }
try { try {
// Check if this client still has access to another resource on this site with the same destination // Check if this client still has access to another resource
// on this specific site with the same destination. We scope
// by siteId (via siteNetworks) rather than networkId because
// removePeerData operates per-site — a resource on a different
// site sharing the same network should not block removal here.
const destinationStillInUse = await trx const destinationStillInUse = await trx
.select() .select()
.from(siteResources) .from(siteResources)
@@ -1276,13 +1401,17 @@ async function handleMessagesForClientResources(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(siteNetworks.networkId, siteResources.networkId)
)
.where( .where(
and( and(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
), ),
eq(siteResources.siteId, resource.siteId), eq(siteNetworks.siteId, siteId),
eq( eq(
siteResources.destination, siteResources.destination,
resource.destination resource.destination
@@ -1304,7 +1433,7 @@ async function handleMessagesForClientResources(
olmJobs.push( olmJobs.push(
removePeerData( removePeerData(
client.clientId, client.clientId,
resource.siteId, siteId,
remoteSubnetsToRemove, remoteSubnetsToRemove,
generateAliasConfig([resource]) generateAliasConfig([resource])
) )
@@ -1316,7 +1445,7 @@ async function handleMessagesForClientResources(
error.message.includes("not found") error.message.includes("not found")
) { ) {
logger.debug( logger.debug(
`Olm data not found for client ${client.clientId} and site ${resource.siteId}, skipping removal` `Olm data not found for client ${client.clientId} and site ${siteId}, skipping removal`
); );
} else { } else {
throw error; throw error;

View File

@@ -21,7 +21,7 @@ import {
roles, roles,
roundTripMessageTracker, roundTripMessageTracker,
siteResources, siteResources,
sites, siteNetworks,
userOrgs userOrgs
} from "@server/db"; } from "@server/db";
import { isLicensedOrSubscribed } from "#private/lib/isLicencedOrSubscribed"; import { isLicensedOrSubscribed } from "#private/lib/isLicencedOrSubscribed";
@@ -62,10 +62,12 @@ const bodySchema = z
export type SignSshKeyResponse = { export type SignSshKeyResponse = {
certificate: string; certificate: string;
messageIds: number[];
messageId: number; messageId: number;
sshUsername: string; sshUsername: string;
sshHost: string; sshHost: string;
resourceId: number; resourceId: number;
siteIds: number[];
siteId: number; siteId: number;
keyId: string; keyId: string;
validPrincipals: string[]; validPrincipals: string[];
@@ -250,10 +252,7 @@ export async function signSshKey(
.update(userOrgs) .update(userOrgs)
.set({ pamUsername: usernameToUse }) .set({ pamUsername: usernameToUse })
.where( .where(
and( and(eq(userOrgs.orgId, orgId), eq(userOrgs.userId, userId))
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, userId)
)
); );
} else { } else {
usernameToUse = userOrg.pamUsername; usernameToUse = userOrg.pamUsername;
@@ -374,21 +373,12 @@ export async function signSshKey(
const homedir = roleRow?.sshCreateHomeDir ?? null; const homedir = roleRow?.sshCreateHomeDir ?? null;
const sudoMode = roleRow?.sshSudoMode ?? "none"; const sudoMode = roleRow?.sshSudoMode ?? "none";
// get the site const sites = await db
const [newt] = await db .select({ siteId: siteNetworks.siteId })
.select() .from(siteNetworks)
.from(newts) .where(eq(siteNetworks.networkId, resource.networkId!));
.where(eq(newts.siteId, resource.siteId))
.limit(1);
if (!newt) { const siteIds = sites.map((site) => site.siteId);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found"
)
);
}
// Sign the public key // Sign the public key
const now = BigInt(Math.floor(Date.now() / 1000)); const now = BigInt(Math.floor(Date.now() / 1000));
@@ -402,6 +392,24 @@ export async function signSshKey(
validBefore: now + validFor validBefore: now + validFor
}); });
const messageIds: number[] = [];
for (const siteId of siteIds) {
// get the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site associated with resource not found"
)
);
}
const [message] = await db const [message] = await db
.insert(roundTripMessageTracker) .insert(roundTripMessageTracker)
.values({ .values({
@@ -420,6 +428,8 @@ export async function signSshKey(
); );
} }
messageIds.push(message.messageId);
await sendToClient(newt.newtId, { await sendToClient(newt.newtId, {
type: `newt/pam/connection`, type: `newt/pam/connection`,
data: { data: {
@@ -439,6 +449,7 @@ export async function signSshKey(
} }
} }
}); });
}
const expiresIn = Number(validFor); // seconds const expiresIn = Number(validFor); // seconds
@@ -459,18 +470,20 @@ export async function signSshKey(
metadata: JSON.stringify({ metadata: JSON.stringify({
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
resource: resource.name, resource: resource.name,
siteId: resource.siteId, siteIds: siteIds
}) })
}); });
return response<SignSshKeyResponse>(res, { return response<SignSshKeyResponse>(res, {
data: { data: {
certificate: cert.certificate, certificate: cert.certificate,
messageId: message.messageId, messageIds: messageIds,
messageId: messageIds[0], // just pick the first one for backward compatibility
sshUsername: usernameToUse, sshUsername: usernameToUse,
sshHost: sshHost, sshHost: sshHost,
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
siteId: resource.siteId, siteIds: siteIds,
siteId: siteIds[0], // just pick the first one for backward compatibility
keyId: cert.keyId, keyId: cert.keyId,
validPrincipals: cert.validPrincipals, validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(), validAfter: cert.validAfter.toISOString(),

View File

@@ -4,8 +4,10 @@ import {
clientSitesAssociationsCache, clientSitesAssociationsCache,
db, db,
ExitNode, ExitNode,
networks,
resources, resources,
Site, Site,
siteNetworks,
siteResources, siteResources,
targetHealthCheck, targetHealthCheck,
targets targets
@@ -137,11 +139,14 @@ export async function buildClientConfigurationForNewtClient(
// Filter out any null values from peers that didn't have an olm // Filter out any null values from peers that didn't have an olm
const validPeers = peers.filter((peer) => peer !== null); const validPeers = peers.filter((peer) => peer !== null);
// Get all enabled site resources for this site // Get all enabled site resources for this site by joining through siteNetworks and networks
const allSiteResources = await db const allSiteResources = await db
.select() .select()
.from(siteResources) .from(siteResources)
.where(eq(siteResources.siteId, siteId)); .innerJoin(networks, eq(siteResources.networkId, networks.networkId))
.innerJoin(siteNetworks, eq(networks.networkId, siteNetworks.networkId))
.where(eq(siteNetworks.siteId, siteId))
.then((rows) => rows.map((r) => r.siteResources));
const targetsToSend: SubnetProxyTarget[] = []; const targetsToSend: SubnetProxyTarget[] = [];

View File

@@ -4,6 +4,8 @@ import {
clientSitesAssociationsCache, clientSitesAssociationsCache,
db, db,
exitNodes, exitNodes,
networks,
siteNetworks,
siteResources, siteResources,
sites sites
} from "@server/db"; } from "@server/db";
@@ -59,9 +61,17 @@ export async function buildSiteConfigurationForOlmClient(
clientSiteResourcesAssociationsCache.siteResourceId clientSiteResourcesAssociationsCache.siteResourceId
) )
) )
.innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
eq(networks.networkId, siteNetworks.networkId)
)
.where( .where(
and( and(
eq(siteResources.siteId, site.siteId), eq(siteNetworks.siteId, site.siteId),
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
@@ -69,6 +79,7 @@ export async function buildSiteConfigurationForOlmClient(
) )
); );
if (jitMode) { if (jitMode) {
// Add site configuration to the array // Add site configuration to the array
siteConfigurations.push({ siteConfigurations.push({

View File

@@ -4,10 +4,12 @@ import {
db, db,
exitNodes, exitNodes,
Site, Site,
siteResources siteNetworks,
siteResources,
sites
} from "@server/db"; } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db"; import { clients, Olm } from "@server/db";
import { and, eq, or } from "drizzle-orm"; import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers"; import { initPeerAddHandshake } from "./peers";
@@ -44,20 +46,31 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
const { siteId, resourceId, chainId } = message.data; const { siteId, resourceId, chainId } = message.data;
let site: Site | null = null; const sendCancel = async () => {
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: { chainId }
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
};
let sitesToProcess: Site[] = [];
if (siteId) { if (siteId) {
// get the site
const [siteRes] = await db const [siteRes] = await db
.select() .select()
.from(sites) .from(sites)
.where(eq(sites.siteId, siteId)) .where(eq(sites.siteId, siteId))
.limit(1); .limit(1);
if (siteRes) { if (siteRes) {
site = siteRes; sitesToProcess = [siteRes];
} }
} } else if (resourceId) {
if (resourceId && !site) {
const resources = await db const resources = await db
.select() .select()
.from(siteResources) .from(siteResources)
@@ -72,27 +85,17 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
); );
if (!resources || resources.length === 0) { if (!resources || resources.length === 0) {
logger.error(`handleOlmServerPeerAddMessage: Resource not found`); logger.error(
// cancel the request from the olm side to not keep doing this `handleOlmServerInitAddPeerHandshake: Resource not found`
await sendToClient( );
olm.olmId, await sendCancel();
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return; return;
} }
if (resources.length > 1) { if (resources.length > 1) {
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches // error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
logger.error( logger.error(
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria` `handleOlmServerInitAddPeerHandshake: Multiple resources found matching the criteria`
); );
return; return;
} }
@@ -117,47 +120,61 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
if (currentResourceAssociationCaches.length === 0) { if (currentResourceAssociationCaches.length === 0) {
logger.error( logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}` `handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
); );
// cancel the request from the olm side to not keep doing this await sendCancel();
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return; return;
} }
const siteIdFromResource = resource.siteId; if (!resource.networkId) {
logger.error(
`handleOlmServerInitAddPeerHandshake: Resource ${resource.siteResourceId} has no network`
);
await sendCancel();
return;
}
// get the site // Get all sites associated with this resource's network via siteNetworks
const [siteRes] = await db const siteRows = await db
.select({ siteId: siteNetworks.siteId })
.from(siteNetworks)
.where(eq(siteNetworks.networkId, resource.networkId));
if (!siteRows || siteRows.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites found for resource ${resource.siteResourceId}`
);
await sendCancel();
return;
}
// Fetch full site objects for all network members
const foundSites = await Promise.all(
siteRows.map(async ({ siteId: sid }) => {
const [s] = await db
.select() .select()
.from(sites) .from(sites)
.where(eq(sites.siteId, siteIdFromResource)); .where(eq(sites.siteId, sid))
if (!siteRes) { .limit(1);
logger.error( return s ?? null;
`handleOlmServerPeerAddMessage: Site with ID ${site} not found` })
); );
sitesToProcess = foundSites.filter((s): s is Site => s !== null);
}
if (sitesToProcess.length === 0) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No sites to process`
);
await sendCancel();
return; return;
} }
site = siteRes; let handshakeInitiated = false;
}
if (!site) { for (const site of sitesToProcess) {
logger.error(`handleOlmServerPeerAddMessage: Site not found`); // Check if the client can access this site using the cache
return;
}
// check if the client can access this site using the cache
const currentSiteAssociationCaches = await db const currentSiteAssociationCaches = await db
.select() .select()
.from(clientSitesAssociationsCache) .from(clientSitesAssociationsCache)
@@ -169,46 +186,19 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
); );
if (currentSiteAssociationCaches.length === 0) { if (currentSiteAssociationCaches.length === 0) {
logger.error( logger.warn(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}` `handleOlmServerInitAddPeerHandshake: Client ${client.clientId} does not have access to site ${site.siteId}, skipping`
); );
// cancel the request from the olm side to not keep doing this continue;
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
} }
if (!site.exitNodeId) { if (!site.exitNodeId) {
logger.error( logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node` `handleOlmServerInitAddPeerHandshake: Site ${site.siteId} has no exit node, skipping`
); );
// cancel the request from the olm side to not keep doing this continue;
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
} }
// get the exit node from the side
const [exitNode] = await db const [exitNode] = await db
.select() .select()
.from(exitNodes) .from(exitNodes)
@@ -216,15 +206,13 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
if (!exitNode) { if (!exitNode) {
logger.error( logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node` `handleOlmServerInitAddPeerHandshake: Exit node not found for site ${site.siteId}, skipping`
); );
return; continue;
} }
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch // Trigger the peer add handshake — if the peer was already added this will be a no-op
// if it has already been added this will be a no-op
await initPeerAddHandshake( await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId, client.clientId,
{ {
siteId: site.siteId, siteId: site.siteId,
@@ -237,5 +225,15 @@ export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
chainId chainId
); );
handshakeInitiated = true;
}
if (!handshakeInitiated) {
logger.error(
`handleOlmServerInitAddPeerHandshake: No accessible sites with valid exit nodes found, cancelling chain`
);
await sendCancel();
}
return; return;
}; };

View File

@@ -1,43 +1,25 @@
import { import {
Client,
clientSiteResourcesAssociationsCache, clientSiteResourcesAssociationsCache,
db, db,
ExitNode, networks,
Org, siteNetworks,
orgs,
roleClients,
roles,
siteResources, siteResources,
Transaction,
userClients,
userOrgs,
users
} from "@server/db"; } from "@server/db";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { import {
clients, clients,
clientSitesAssociationsCache, clientSitesAssociationsCache,
exitNodes,
Olm, Olm,
olms,
sites sites
} from "@server/db"; } from "@server/db";
import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm"; import { and, eq, inArray, isNotNull, isNull } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
import logger from "@server/logger"; import logger from "@server/logger";
import { listExitNodes } from "#dynamic/lib/exitNodes";
import { import {
generateAliasConfig, generateAliasConfig,
getNextAvailableClientSubnet
} from "@server/lib/ip"; } from "@server/lib/ip";
import { generateRemoteSubnets } from "@server/lib/ip"; import { generateRemoteSubnets } from "@server/lib/ip";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { validateSessionToken } from "@server/auth/sessions/app";
import config from "@server/lib/config";
import { import {
addPeer as newtAddPeer, addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers"; } from "@server/routers/newt/peers";
export const handleOlmServerPeerAddMessage: MessageHandler = async ( export const handleOlmServerPeerAddMessage: MessageHandler = async (
@@ -153,14 +135,22 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
clientSiteResourcesAssociationsCache.siteResourceId clientSiteResourcesAssociationsCache.siteResourceId
) )
) )
.where( .innerJoin(
networks,
eq(siteResources.networkId, networks.networkId)
)
.innerJoin(
siteNetworks,
and( and(
eq(siteResources.siteId, site.siteId), eq(networks.networkId, siteNetworks.networkId),
eq(siteNetworks.siteId, site.siteId)
)
)
.where(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
) )
)
); );
// Return connect message with all site configurations // Return connect message with all site configurations

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db, Site, siteResources } from "@server/db"; import { db, Site, siteNetworks, siteResources } from "@server/db";
import { newts, newtSessions, sites } from "@server/db"; import { newts, newtSessions, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import response from "@server/lib/response"; import response from "@server/lib/response";
@@ -71,19 +71,24 @@ export async function deleteSite(
await deletePeer(site.exitNodeId!, site.pubKey); await deletePeer(site.exitNodeId!, site.pubKey);
} }
} else if (site.type == "newt") { } else if (site.type == "newt") {
// delete all of the site resources on this site const networks = await trx
const siteResourcesOnSite = trx .select({ networkId: siteNetworks.networkId })
.delete(siteResources) .from(siteNetworks)
.where(eq(siteResources.siteId, siteId)) .where(eq(siteNetworks.siteId, siteId));
.returning();
// loop through them // loop through them
for (const removedSiteResource of await siteResourcesOnSite) { for (const network of await networks) {
const [siteResource] = await trx
.select()
.from(siteResources)
.where(eq(siteResources.networkId, network.networkId));
if (siteResource) {
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
removedSiteResource, siteResource,
trx trx
); );
} }
}
// get the newt on the site by querying the newt table for siteId // get the newt on the site by querying the newt table for siteId
const [deletedNewt] = await trx const [deletedNewt] = await trx

View File

@@ -5,6 +5,8 @@ import {
orgs, orgs,
roles, roles,
roleSiteResources, roleSiteResources,
siteNetworks,
networks,
SiteResource, SiteResource,
siteResources, siteResources,
sites, sites,
@@ -23,7 +25,7 @@ import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm"; import { and, eq, inArray } from "drizzle-orm";
import { NextFunction, Request, Response } from "express"; import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { z } from "zod"; import { z } from "zod";
@@ -37,7 +39,7 @@ const createSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255), name: z.string().min(1).max(255),
mode: z.enum(["host", "cidr", "port"]), mode: z.enum(["host", "cidr", "port"]),
siteId: z.int(), siteIds: z.array(z.int()),
// protocol: z.enum(["tcp", "udp"]).optional(), // protocol: z.enum(["tcp", "udp"]).optional(),
// proxyPort: z.int().positive().optional(), // proxyPort: z.int().positive().optional(),
// destinationPort: z.int().positive().optional(), // destinationPort: z.int().positive().optional(),
@@ -159,7 +161,7 @@ export async function createSiteResource(
const { orgId } = parsedParams.data; const { orgId } = parsedParams.data;
const { const {
name, name,
siteId, siteIds,
mode, mode,
// protocol, // protocol,
// proxyPort, // proxyPort,
@@ -178,14 +180,16 @@ export async function createSiteResource(
} = parsedBody.data; } = parsedBody.data;
// Verify the site exists and belongs to the org // Verify the site exists and belongs to the org
const [site] = await db const sitesToAssign = await db
.select() .select()
.from(sites) .from(sites)
.where(and(eq(sites.siteId, siteId), eq(sites.orgId, orgId))) .where(and(inArray(sites.siteId, siteIds), eq(sites.orgId, orgId)))
.limit(1); .limit(1);
if (!site) { if (sitesToAssign.length !== siteIds.length) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
} }
const [org] = await db const [org] = await db
@@ -287,12 +291,29 @@ export async function createSiteResource(
let newSiteResource: SiteResource | undefined; let newSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
const [network] = await trx
.insert(networks)
.values({
scope: "resource",
orgId: orgId
})
.returning();
if (!network) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to create network`
)
);
}
// Create the site resource // Create the site resource
const insertValues: typeof siteResources.$inferInsert = { const insertValues: typeof siteResources.$inferInsert = {
siteId,
niceId, niceId,
orgId, orgId,
name, name,
networkId: network.networkId,
mode: mode as "host" | "cidr", mode: mode as "host" | "cidr",
destination, destination,
enabled, enabled,
@@ -317,6 +338,13 @@ export async function createSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: network.networkId
});
}
const [adminRole] = await trx const [adminRole] = await trx
.select() .select()
.from(roles) .from(roles)
@@ -359,17 +387,22 @@ export async function createSiteResource(
); );
} }
for (const siteToAssign of sitesToAssign) {
const [newt] = await trx const [newt] = await trx
.select() .select()
.from(newts) .from(newts)
.where(eq(newts.siteId, site.siteId)) .where(eq(newts.siteId, siteToAssign.siteId))
.limit(1); .limit(1);
if (!newt) { if (!newt) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") createHttpError(
HttpCode.NOT_FOUND,
`Newt not found for site ${siteToAssign.siteId}`
)
); );
} }
}
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
newSiteResource, newSiteResource,
@@ -387,7 +420,7 @@ export async function createSiteResource(
} }
logger.info( logger.info(
`Created site resource ${newSiteResource.siteResourceId} for site ${siteId}` `Created site resource ${newSiteResource.siteResourceId} for org ${orgId}`
); );
return response(res, { return response(res, {

View File

@@ -70,17 +70,18 @@ export async function deleteSiteResource(
.where(and(eq(siteResources.siteResourceId, siteResourceId))) .where(and(eq(siteResources.siteResourceId, siteResourceId)))
.returning(); .returning();
const [newt] = await trx // not sure why this is here...
.select() // const [newt] = await trx
.from(newts) // .select()
.where(eq(newts.siteId, removedSiteResource.siteId)) // .from(newts)
.limit(1); // .where(eq(newts.siteId, removedSiteResource.siteId))
// .limit(1);
if (!newt) { // if (!newt) {
return next( // return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") // createHttpError(HttpCode.NOT_FOUND, "Newt not found")
); // );
} // }
await rebuildClientAssociationsFromSiteResource( await rebuildClientAssociationsFromSiteResource(
removedSiteResource, removedSiteResource,

View File

@@ -17,38 +17,34 @@ const getSiteResourceParamsSchema = z.strictObject({
.transform((val) => (val ? Number(val) : undefined)) .transform((val) => (val ? Number(val) : undefined))
.pipe(z.int().positive().optional()) .pipe(z.int().positive().optional())
.optional(), .optional(),
siteId: z.string().transform(Number).pipe(z.int().positive()),
niceId: z.string().optional(), niceId: z.string().optional(),
orgId: z.string() orgId: z.string()
}); });
async function query( async function query(
siteResourceId?: number, siteResourceId?: number,
siteId?: number,
niceId?: string, niceId?: string,
orgId?: string orgId?: string
) { ) {
if (siteResourceId && siteId && orgId) { if (siteResourceId && orgId) {
const [siteResource] = await db const [siteResource] = await db
.select() .select()
.from(siteResources) .from(siteResources)
.where( .where(
and( and(
eq(siteResources.siteResourceId, siteResourceId), eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
) )
.limit(1); .limit(1);
return siteResource; return siteResource;
} else if (niceId && siteId && orgId) { } else if (niceId && orgId) {
const [siteResource] = await db const [siteResource] = await db
.select() .select()
.from(siteResources) .from(siteResources)
.where( .where(
and( and(
eq(siteResources.niceId, niceId), eq(siteResources.niceId, niceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
) )
@@ -84,7 +80,6 @@ registry.registerPath({
request: { request: {
params: z.object({ params: z.object({
niceId: z.string(), niceId: z.string(),
siteId: z.number(),
orgId: z.string() orgId: z.string()
}) })
}, },
@@ -107,10 +102,10 @@ export async function getSiteResource(
); );
} }
const { siteResourceId, siteId, niceId, orgId } = parsedParams.data; const { siteResourceId, niceId, orgId } = parsedParams.data;
// Get the site resource // Get the site resource
const siteResource = await query(siteResourceId, siteId, niceId, orgId); const siteResource = await query(siteResourceId, niceId, orgId);
if (!siteResource) { if (!siteResource) {
return next( return next(

View File

@@ -1,4 +1,4 @@
import { db, SiteResource, siteResources, sites } from "@server/db"; import { db, SiteResource, siteNetworks, siteResources, sites } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -73,9 +73,10 @@ const listAllSiteResourcesByOrgQuerySchema = z.object({
export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{ export type ListAllSiteResourcesByOrgResponse = PaginatedResponse<{
siteResources: (SiteResource & { siteResources: (SiteResource & {
siteName: string; siteIds: number[];
siteNiceId: string; siteNames: string[];
siteAddress: string | null; siteNiceIds: string[];
siteAddresses: (string | null)[];
})[]; })[];
}>; }>;
@@ -83,7 +84,6 @@ function querySiteResourcesBase() {
return db return db
.select({ .select({
siteResourceId: siteResources.siteResourceId, siteResourceId: siteResources.siteResourceId,
siteId: siteResources.siteId,
orgId: siteResources.orgId, orgId: siteResources.orgId,
niceId: siteResources.niceId, niceId: siteResources.niceId,
name: siteResources.name, name: siteResources.name,
@@ -100,14 +100,20 @@ function querySiteResourcesBase() {
disableIcmp: siteResources.disableIcmp, disableIcmp: siteResources.disableIcmp,
authDaemonMode: siteResources.authDaemonMode, authDaemonMode: siteResources.authDaemonMode,
authDaemonPort: siteResources.authDaemonPort, authDaemonPort: siteResources.authDaemonPort,
siteName: sites.name, networkId: siteResources.networkId,
siteNiceId: sites.niceId, defaultNetworkId: siteResources.defaultNetworkId,
siteAddress: sites.address siteNames: sql<string[]>`array_agg(${sites.name})`,
siteNiceIds: sql<string[]>`array_agg(${sites.niceId})`,
siteIds: sql<number[]>`array_agg(${sites.siteId})`,
siteAddresses: sql<(string | null)[]>`array_agg(${sites.address})`
}) })
.from(siteResources) .from(siteResources)
.innerJoin(sites, eq(siteResources.siteId, sites.siteId)); .innerJoin(siteNetworks, eq(siteResources.networkId, siteNetworks.networkId))
.innerJoin(sites, eq(siteNetworks.siteId, sites.siteId))
.groupBy(siteResources.siteResourceId);
} }
registry.registerPath({ registry.registerPath({
method: "get", method: "get",
path: "/org/{orgId}/site-resources", path: "/org/{orgId}/site-resources",

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, networks, siteNetworks } from "@server/db";
import { siteResources, sites, SiteResource } from "@server/db"; import { siteResources, sites, SiteResource } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -108,13 +108,21 @@ export async function listSiteResources(
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found")); return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
} }
// Get site resources // Get site resources by joining networks to siteResources via siteNetworks
const siteResourcesList = await db const siteResourcesList = await db
.select() .select()
.from(siteResources) .from(siteNetworks)
.innerJoin(
networks,
eq(siteNetworks.networkId, networks.networkId)
)
.innerJoin(
siteResources,
eq(siteResources.networkId, networks.networkId)
)
.where( .where(
and( and(
eq(siteResources.siteId, siteId), eq(siteNetworks.siteId, siteId),
eq(siteResources.orgId, orgId) eq(siteResources.orgId, orgId)
) )
) )
@@ -128,6 +136,7 @@ export async function listSiteResources(
.limit(limit) .limit(limit)
.offset(offset); .offset(offset);
return response(res, { return response(res, {
data: { siteResources: siteResourcesList }, data: { siteResources: siteResourcesList },
success: true, success: true,

View File

@@ -8,7 +8,9 @@ import {
orgs, orgs,
roles, roles,
roleSiteResources, roleSiteResources,
siteNetworks,
sites, sites,
networks,
Transaction, Transaction,
userSiteResources userSiteResources
} from "@server/db"; } from "@server/db";
@@ -16,7 +18,7 @@ import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import { eq, and, ne } from "drizzle-orm"; import { eq, and, ne, inArray } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
@@ -42,7 +44,7 @@ const updateSiteResourceParamsSchema = z.strictObject({
const updateSiteResourceSchema = z const updateSiteResourceSchema = z
.strictObject({ .strictObject({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional(),
siteId: z.int(), siteIds: z.array(z.int()),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(), // niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(), // mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(), mode: z.enum(["host", "cidr"]).optional(),
@@ -166,7 +168,7 @@ export async function updateSiteResource(
const { siteResourceId } = parsedParams.data; const { siteResourceId } = parsedParams.data;
const { const {
name, name,
siteId, // because it can change siteIds, // because it can change
mode, mode,
destination, destination,
alias, alias,
@@ -181,16 +183,6 @@ export async function updateSiteResource(
authDaemonMode authDaemonMode
} = parsedBody.data; } = parsedBody.data;
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(createHttpError(HttpCode.NOT_FOUND, "Site not found"));
}
// Check if site resource exists // Check if site resource exists
const [existingSiteResource] = await db const [existingSiteResource] = await db
.select() .select()
@@ -230,6 +222,24 @@ export async function updateSiteResource(
); );
} }
// Verify the site exists and belongs to the org
const sitesToAssign = await db
.select()
.from(sites)
.where(
and(
inArray(sites.siteId, siteIds),
eq(sites.orgId, existingSiteResource.orgId)
)
)
.limit(1);
if (sitesToAssign.length !== siteIds.length) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Some site not found")
);
}
// Only check if destination is an IP address // Only check if destination is an IP address
const isIp = z const isIp = z
.union([z.ipv4(), z.ipv6()]) .union([z.ipv4(), z.ipv6()])
@@ -247,25 +257,24 @@ export async function updateSiteResource(
); );
} }
let existingSite = site; let sitesChanged = false;
let siteChanged = false; const existingSiteIds = existingSiteResource.networkId
if (existingSiteResource.siteId !== siteId) { ? await db
siteChanged = true;
// get the existing site
[existingSite] = await db
.select() .select()
.from(sites) .from(siteNetworks)
.where(eq(sites.siteId, existingSiteResource.siteId)) .where(
.limit(1); eq(siteNetworks.networkId, existingSiteResource.networkId)
if (!existingSite) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"Existing site not found"
) )
); : [];
}
const existingSiteIdSet = new Set(existingSiteIds.map((s) => s.siteId));
const newSiteIdSet = new Set(siteIds);
if (
existingSiteIdSet.size !== newSiteIdSet.size ||
![...existingSiteIdSet].every((id) => newSiteIdSet.has(id))
) {
sitesChanged = true;
} }
// make sure the alias is unique within the org if provided // make sure the alias is unique within the org if provided
@@ -295,7 +304,7 @@ export async function updateSiteResource(
let updatedSiteResource: SiteResource | undefined; let updatedSiteResource: SiteResource | undefined;
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
// if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place // if the site is changed we need to delete and recreate the resource to avoid complications with the rebuild function otherwise we can just update in place
if (siteChanged) { if (sitesChanged) {
// delete the existing site resource // delete the existing site resource
await trx await trx
.delete(siteResources) .delete(siteResources)
@@ -321,7 +330,8 @@ export async function updateSiteResource(
const sshPamSet = const sshPamSet =
isLicensedSshPam && isLicensedSshPam &&
(authDaemonPort !== undefined || authDaemonMode !== undefined) (authDaemonPort !== undefined ||
authDaemonMode !== undefined)
? { ? {
...(authDaemonPort !== undefined && { ...(authDaemonPort !== undefined && {
authDaemonPort authDaemonPort
@@ -335,7 +345,6 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -423,7 +432,8 @@ export async function updateSiteResource(
// Update the site resource // Update the site resource
const sshPamSet = const sshPamSet =
isLicensedSshPam && isLicensedSshPam &&
(authDaemonPort !== undefined || authDaemonMode !== undefined) (authDaemonPort !== undefined ||
authDaemonMode !== undefined)
? { ? {
...(authDaemonPort !== undefined && { ...(authDaemonPort !== undefined && {
authDaemonPort authDaemonPort
@@ -437,7 +447,6 @@ export async function updateSiteResource(
.update(siteResources) .update(siteResources)
.set({ .set({
name: name, name: name,
siteId: siteId,
mode: mode, mode: mode,
destination: destination, destination: destination,
enabled: enabled, enabled: enabled,
@@ -454,6 +463,20 @@ export async function updateSiteResource(
//////////////////// update the associations //////////////////// //////////////////// update the associations ////////////////////
// delete the site - site resources associations
await trx
.delete(siteNetworks)
.where(
eq(siteNetworks.networkId, updatedSiteResource.networkId!)
);
for (const siteId of siteIds) {
await trx.insert(siteNetworks).values({
siteId: siteId,
networkId: updatedSiteResource.networkId!
});
}
await trx await trx
.delete(clientSiteResources) .delete(clientSiteResources)
.where( .where(
@@ -524,13 +547,16 @@ export async function updateSiteResource(
} }
logger.info( logger.info(
`Updated site resource ${siteResourceId} for site ${siteId}` `Updated site resource ${siteResourceId}`
); );
await handleMessagingForUpdatedSiteResource( await handleMessagingForUpdatedSiteResource(
existingSiteResource, existingSiteResource,
updatedSiteResource, updatedSiteResource,
{ siteId: site.siteId, orgId: site.orgId }, siteIds.map((siteId) => ({
siteId,
orgId: existingSiteResource.orgId
})),
trx trx
); );
} }
@@ -557,7 +583,7 @@ export async function updateSiteResource(
export async function handleMessagingForUpdatedSiteResource( export async function handleMessagingForUpdatedSiteResource(
existingSiteResource: SiteResource | undefined, existingSiteResource: SiteResource | undefined,
updatedSiteResource: SiteResource, updatedSiteResource: SiteResource,
site: { siteId: number; orgId: string }, sites: { siteId: number; orgId: string }[],
trx: Transaction trx: Transaction
) { ) {
logger.debug( logger.debug(
@@ -594,6 +620,7 @@ export async function handleMessagingForUpdatedSiteResource(
// if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all // if the existingSiteResource is undefined (new resource) we don't need to do anything here, the rebuild above handled it all
if (destinationChanged || aliasChanged || portRangesChanged) { if (destinationChanged || aliasChanged || portRangesChanged) {
for (const site of sites) {
const [newt] = await trx const [newt] = await trx
.select() .select()
.from(newts) .from(newts)
@@ -617,10 +644,14 @@ export async function handleMessagingForUpdatedSiteResource(
mergedAllClients mergedAllClients
); );
await updateTargets(newt.newtId, { await updateTargets(
newt.newtId,
{
oldTargets: oldTargets, oldTargets: oldTargets,
newTargets: newTargets newTargets: newTargets
}, newt.version); },
newt.version
);
} }
const olmJobs: Promise<void>[] = []; const olmJobs: Promise<void>[] = [];
@@ -637,13 +668,20 @@ export async function handleMessagingForUpdatedSiteResource(
siteResources.siteResourceId siteResources.siteResourceId
) )
) )
.innerJoin(
siteNetworks,
eq(
siteNetworks.networkId,
siteResources.networkId
)
)
.where( .where(
and( and(
eq( eq(
clientSiteResourcesAssociationsCache.clientId, clientSiteResourcesAssociationsCache.clientId,
client.clientId client.clientId
), ),
eq(siteResources.siteId, site.siteId), eq(siteNetworks.siteId, site.siteId),
eq( eq(
siteResources.destination, siteResources.destination,
existingSiteResource.destination existingSiteResource.destination
@@ -655,6 +693,7 @@ export async function handleMessagingForUpdatedSiteResource(
) )
); );
const oldDestinationStillInUseByASite = const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0; oldDestinationStillInUseSites.length > 0;
@@ -662,10 +701,11 @@ export async function handleMessagingForUpdatedSiteResource(
olmJobs.push( olmJobs.push(
updatePeerData( updatePeerData(
client.clientId, client.clientId,
updatedSiteResource.siteId, site.siteId,
destinationChanged destinationChanged
? { ? {
oldRemoteSubnets: !oldDestinationStillInUseByASite oldRemoteSubnets:
!oldDestinationStillInUseByASite
? generateRemoteSubnets([ ? generateRemoteSubnets([
existingSiteResource existingSiteResource
]) ])
@@ -691,4 +731,5 @@ export async function handleMessagingForUpdatedSiteResource(
await Promise.all(olmJobs); await Promise.all(olmJobs);
} }
}
} }

View File

@@ -60,17 +60,17 @@ export default async function ClientResourcesPage(
id: siteResource.siteResourceId, id: siteResource.siteResourceId,
name: siteResource.name, name: siteResource.name,
orgId: params.orgId, orgId: params.orgId,
siteName: siteResource.siteName, siteNames: siteResource.siteNames,
siteAddress: siteResource.siteAddress || null, siteAddresses: siteResource.siteAddresses || null,
mode: siteResource.mode || ("port" as any), mode: siteResource.mode || ("port" as any),
// protocol: siteResource.protocol, // protocol: siteResource.protocol,
// proxyPort: siteResource.proxyPort, // proxyPort: siteResource.proxyPort,
siteId: siteResource.siteId, siteIds: siteResource.siteIds,
destination: siteResource.destination, destination: siteResource.destination,
// destinationPort: siteResource.destinationPort, // destinationPort: siteResource.destinationPort,
alias: siteResource.alias || null, alias: siteResource.alias || null,
aliasAddress: siteResource.aliasAddress || null, aliasAddress: siteResource.aliasAddress || null,
siteNiceId: siteResource.siteNiceId, siteNiceIds: siteResource.siteNiceIds,
niceId: siteResource.niceId, niceId: siteResource.niceId,
tcpPortRangeString: siteResource.tcpPortRangeString || null, tcpPortRangeString: siteResource.tcpPortRangeString || null,
udpPortRangeString: siteResource.udpPortRangeString || null, udpPortRangeString: siteResource.udpPortRangeString || null,

View File

@@ -21,6 +21,7 @@ import {
ArrowUp10Icon, ArrowUp10Icon,
ArrowUpDown, ArrowUpDown,
ArrowUpRight, ArrowUpRight,
ChevronDown,
ChevronsUpDownIcon, ChevronsUpDownIcon,
MoreHorizontal MoreHorizontal
} from "lucide-react"; } from "lucide-react";
@@ -43,14 +44,14 @@ export type InternalResourceRow = {
id: number; id: number;
name: string; name: string;
orgId: string; orgId: string;
siteName: string; siteNames: string[];
siteAddress: string | null; siteAddresses: (string | null)[];
siteIds: number[];
siteNiceIds: string[];
// mode: "host" | "cidr" | "port"; // mode: "host" | "cidr" | "port";
mode: "host" | "cidr"; mode: "host" | "cidr";
// protocol: string | null; // protocol: string | null;
// proxyPort: number | null; // proxyPort: number | null;
siteId: number;
siteNiceId: string;
destination: string; destination: string;
// destinationPort: number | null; // destinationPort: number | null;
alias: string | null; alias: string | null;
@@ -136,6 +137,60 @@ export default function ClientResourcesTable({
} }
}; };
function SiteCell({ resourceRow }: { resourceRow: InternalResourceRow }) {
const { siteNames, siteNiceIds, orgId } = resourceRow;
if (!siteNames || siteNames.length === 0) {
return <span>-</span>;
}
if (siteNames.length === 1) {
return (
<Link
href={`/${orgId}/settings/sites/${siteNiceIds[0]}`}
>
<Button variant="outline">
{siteNames[0]}
<ArrowUpRight className="ml-2 h-4 w-4" />
</Button>
</Link>
);
}
return (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
variant="outline"
size="sm"
className="flex items-center gap-2"
>
<span>
{siteNames.length} {t("sites")}
</span>
<ChevronDown className="h-3 w-3" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="start">
{siteNames.map((siteName, idx) => (
<DropdownMenuItem
key={siteNiceIds[idx]}
asChild
>
<Link
href={`/${orgId}/settings/sites/${siteNiceIds[idx]}`}
className="flex items-center gap-2 cursor-pointer"
>
{siteName}
<ArrowUpRight className="h-3 w-3" />
</Link>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
);
}
const internalColumns: ExtendedColumnDef<InternalResourceRow>[] = [ const internalColumns: ExtendedColumnDef<InternalResourceRow>[] = [
{ {
accessorKey: "name", accessorKey: "name",
@@ -185,21 +240,11 @@ export default function ClientResourcesTable({
} }
}, },
{ {
accessorKey: "siteName", accessorKey: "siteNames",
friendlyName: t("site"), friendlyName: t("site"),
header: () => <span className="p-3">{t("site")}</span>, header: () => <span className="p-3">{t("site")}</span>,
cell: ({ row }) => { cell: ({ row }) => {
const resourceRow = row.original; return <SiteCell resourceRow={row.original} />;
return (
<Link
href={`/${resourceRow.orgId}/settings/sites/${resourceRow.siteNiceId}`}
>
<Button variant="outline">
{resourceRow.siteName}
<ArrowUpRight className="ml-2 h-4 w-4" />
</Button>
</Link>
);
} }
}, },
{ {
@@ -399,7 +444,7 @@ export default function ClientResourcesTable({
onConfirm={async () => onConfirm={async () =>
deleteInternalResource( deleteInternalResource(
selectedInternalResource!.id, selectedInternalResource!.id,
selectedInternalResource!.siteId selectedInternalResource!.siteIds[0]
) )
} }
string={selectedInternalResource.name} string={selectedInternalResource.name}
@@ -433,7 +478,11 @@ export default function ClientResourcesTable({
<EditInternalResourceDialog <EditInternalResourceDialog
open={isEditDialogOpen} open={isEditDialogOpen}
setOpen={setIsEditDialogOpen} setOpen={setIsEditDialogOpen}
resource={editingResource} resource={{
...editingResource,
siteName: editingResource.siteNames[0] ?? "",
siteId: editingResource.siteIds[0]
}}
orgId={orgId} orgId={orgId}
sites={sites} sites={sites}
onSuccess={() => { onSuccess={() => {