Remove siteIds and build associations from user role chnages

This commit is contained in:
Owen
2025-11-06 17:59:34 -08:00
parent bea1c65076
commit ec1f94791a
9 changed files with 669 additions and 459 deletions

View File

@@ -0,0 +1,389 @@
import {
Client,
clients,
clientSites,
db,
exitNodes,
newts,
olms,
roleSiteResources,
Site,
SiteResource,
sites,
Transaction,
userOrgs,
users,
userSiteResources
} from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "@server/routers/newt/peers";
import {
addPeer as olmAddPeer,
deletePeer as olmDeletePeer
} from "@server/routers/olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import logger from "@server/logger";
export async function rebuildSiteClientAssociations(
siteResource: SiteResource,
trx: Transaction | typeof db = db
): Promise<void> {
const siteId = siteResource.siteId;
// get the site
const [site] = await trx
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
const roleIds = await trx
.select()
.from(roleSiteResources)
.where(
eq(roleSiteResources.siteResourceId, siteResource.siteResourceId)
)
.then((rows) => rows.map((row) => row.roleId));
const directUserIds = await trx
.select()
.from(userSiteResources)
.where(
eq(userSiteResources.siteResourceId, siteResource.siteResourceId)
)
.then((rows) => rows.map((row) => row.userId));
// get all of the users in these roles
const userIdsFromRoles = await trx
.select({
userId: users.userId
})
.from(userOrgs)
.where(inArray(userOrgs.roleId, roleIds))
.then((rows) => rows.map((row) => row.userId));
const allUserIds = Array.from(
new Set([...directUserIds, ...userIdsFromRoles])
);
const allClients = await trx
.select({
clientId: clients.clientId,
pubKey: clients.pubKey,
subnet: clients.subnet
})
.from(clients)
.where(inArray(clients.userId, allUserIds));
const allClientIds = allClients.map((client) => client.clientId);
const existingClientSiteIds = await trx
.select({
clientId: clientSites.clientId
})
.from(clientSites)
.where(eq(clientSites.siteId, siteId))
.then((rows) => rows.map((row) => row.clientId));
const clientSitesToAdd = allClientIds.filter(
(clientId) => !existingClientSiteIds.includes(clientId)
);
const clientSitesToInsert = allClientIds
.filter((clientId) => !existingClientSiteIds.includes(clientId))
.map((clientId) => ({
clientId,
siteId
}));
if (clientSitesToInsert.length > 0) {
await trx.insert(clientSites).values(clientSitesToInsert);
}
// Now remove any client-site associations that should no longer exist
const clientSitesToRemove = existingClientSiteIds.filter(
(clientId) => !allClientIds.includes(clientId)
);
if (clientSitesToRemove.length > 0) {
await trx
.delete(clientSites)
.where(
and(
eq(clientSites.siteId, siteId),
inArray(clientSites.clientId, clientSitesToRemove)
)
);
}
// Now handle the messages to add/remove peers on both the newt and olm sides
await handleMessagesForSiteClients(
site,
siteId,
allClients,
clientSitesToAdd,
clientSitesToRemove,
trx
);
}
async function handleMessagesForSiteClients(
site: Site,
siteId: number,
allClients: {
clientId: number;
pubKey: string | null;
subnet: string | null;
}[],
clientSitesToAdd: number[],
clientSitesToRemove: number[],
trx: Transaction | typeof db = db
): Promise<void> {
if (!site.exitNodeId) {
logger.warn(
`Exit node ID not on site ${site.siteId} so there is no reason to update clients because it must be offline`
);
return;
}
// get the exit node for the site
const [exitNode] = await trx
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (!exitNode) {
logger.warn(
`Exit node not found for site ${site.siteId} so there is no reason to update clients because it must be offline`
);
return;
}
if (!site.publicKey) {
logger.warn(
`Site publicKey not set for site ${site.siteId} so cannot add peers to clients`
);
return;
}
const [newt] = await trx
.select({
newtId: newts.newtId
})
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
logger.warn(
`Newt not found for site ${siteId} so cannot add peers to clients`
);
return;
}
let newtJobs: Promise<any>[] = [];
let olmJobs: Promise<any>[] = [];
let exitNodeJobs: Promise<any>[] = [];
for (const client of allClients) {
// UPDATE THE NEWT
if (!client.subnet || !client.pubKey) {
logger.debug("Client subnet, pubKey or endpoint is not set");
continue;
}
// is this an add or a delete?
const isAdd = clientSitesToAdd.includes(client.clientId);
const isDelete = clientSitesToRemove.includes(client.clientId);
if (!isAdd && !isDelete) {
// nothing to do for this client
continue;
}
const [olm] = await trx
.select({
olmId: olms.olmId
})
.from(olms)
.where(eq(olms.clientId, client.clientId))
.limit(1);
if (!olm) {
logger.warn(
`Olm not found for client ${client.clientId} so cannot add/delete peers`
);
continue;
}
if (isDelete) {
newtJobs.push(newtDeletePeer(siteId, client.pubKey, newt.newtId));
olmJobs.push(
olmDeletePeer(
client.clientId,
siteId,
site.publicKey,
olm.olmId
)
);
}
if (isAdd) {
// TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES
// BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS
// AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES
const isRelayed = true;
newtJobs.push(
newtAddPeer(
siteId,
{
publicKey: client.pubKey,
allowedIps: [`${client.subnet.split("/")[-1]}/32`], // we want to only allow from that client
// endpoint: isRelayed ? "" : clientSite.endpoint
endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint
},
newt.newtId
)
);
olmJobs.push(
olmAddPeer(
client.clientId,
{
siteId: site.siteId,
endpoint:
isRelayed || !site.endpoint
? `${exitNode.endpoint}:21820`
: site.endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: site.remoteSubnets
},
olm.olmId
)
);
}
exitNodeJobs.push(updateClientSiteDestinations(client, trx));
}
await Promise.all(exitNodeJobs);
await Promise.all(newtJobs); // do the servers first to make sure they are ready?
await Promise.all(olmJobs);
}
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
// this updates the relay destinations for a client to point to all of the new sites
export async function updateClientSiteDestinations(
client: {
clientId: number;
pubKey: string | null;
subnet: string | null;
},
trx: Transaction | typeof db = db
): Promise<void> {
let exitNodeDestinations: {
reachableAt: string;
exitNodeId: number;
type: string;
name: string;
sourceIp: string;
sourcePort: number;
destinations: PeerDestination[];
}[] = [];
const sitesData = await trx
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId))
.where(eq(clientSites.clientId, client.clientId));
for (const site of sitesData) {
if (!site.sites.subnet) {
logger.warn(`Site ${site.sites.siteId} has no subnet, skipping`);
continue;
}
if (!site.clientSites.endpoint) {
logger.warn(`Site ${site.sites.siteId} has no endpoint, skipping`);
continue;
}
// find the destinations in the array
let destinations = exitNodeDestinations.find(
(d) => d.reachableAt === site.exitNodes?.reachableAt
);
if (!destinations) {
destinations = {
reachableAt: site.exitNodes?.reachableAt || "",
exitNodeId: site.exitNodes?.exitNodeId || 0,
type: site.exitNodes?.type || "",
name: site.exitNodes?.name || "",
sourceIp: site.clientSites.endpoint.split(":")[0] || "",
sourcePort:
parseInt(site.clientSites.endpoint.split(":")[1]) || 0,
destinations: [
{
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
}
]
};
} else {
// add to the existing destinations
destinations.destinations.push({
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
});
}
// update it in the array
exitNodeDestinations = exitNodeDestinations.filter(
(d) => d.reachableAt !== site.exitNodes?.reachableAt
);
exitNodeDestinations.push(destinations);
}
for (const destination of exitNodeDestinations) {
logger.info(
`Updating destinations for exit node at ${destination.reachableAt}`
);
const payload = {
sourceIp: destination.sourceIp,
sourcePort: destination.sourcePort,
destinations: destination.destinations
};
logger.info(
`Payload for update-destinations: ${JSON.stringify(payload, null, 2)}`
);
// Create an ExitNode-like object for sendToExitNode
const exitNodeForComm = {
exitNodeId: destination.exitNodeId,
type: destination.type,
reachableAt: destination.reachableAt,
name: destination.name
} as any; // Using 'as any' since we know sendToExitNode will handle this correctly
await sendToExitNode(exitNodeForComm, {
remoteType: "remoteExitNode/update-destinations",
localPath: "/update-destinations",
method: "POST",
data: payload
});
}
}

View File

@@ -9,15 +9,6 @@ import logger from "@server/logger";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import {
addPeer as newtAddPeer,
deletePeer as newtDeletePeer
} from "../newt/peers";
import {
addPeer as olmAddPeer,
deletePeer as olmDeletePeer
} from "../olm/peers";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
const updateClientParamsSchema = z const updateClientParamsSchema = z
.object({ .object({
@@ -27,10 +18,7 @@ const updateClientParamsSchema = z
const updateClientSchema = z const updateClientSchema = z
.object({ .object({
name: z.string().min(1).max(255).optional(), name: z.string().min(1).max(255).optional()
siteIds: z
.array(z.number().int().positive())
.optional()
}) })
.strict(); .strict();
@@ -54,11 +42,6 @@ registry.registerPath({
responses: {} responses: {}
}); });
interface PeerDestination {
destinationIP: string;
destinationPort: number;
}
export async function updateClient( export async function updateClient(
req: Request, req: Request,
res: Response, res: Response,
@@ -75,7 +58,7 @@ export async function updateClient(
); );
} }
const { name, siteIds } = parsedBody.data; const { name } = parsedBody.data;
const parsedParams = updateClientParamsSchema.safeParse(req.params); const parsedParams = updateClientParamsSchema.safeParse(req.params);
if (!parsedParams.success) { if (!parsedParams.success) {
@@ -105,266 +88,11 @@ export async function updateClient(
); );
} }
let sitesAdded = []; const updatedClient = await db
let sitesRemoved = []; .update(clients)
.set({ name })
// Fetch existing site associations .where(eq(clients.clientId, clientId))
const existingSites = await db .returning();
.select({ siteId: clientSites.siteId })
.from(clientSites)
.where(eq(clientSites.clientId, clientId));
const existingSiteIds = existingSites.map((site) => site.siteId);
const siteIdsToProcess = siteIds || [];
// Determine which sites were added and removed
sitesAdded = siteIdsToProcess.filter(
(siteId) => !existingSiteIds.includes(siteId)
);
sitesRemoved = existingSiteIds.filter(
(siteId) => !siteIdsToProcess.includes(siteId)
);
let updatedClient: Client | undefined = undefined;
let sitesData: any; // TODO: define type somehow from the query below
await db.transaction(async (trx) => {
// Update client name if provided
if (name) {
await trx
.update(clients)
.set({ name })
.where(eq(clients.clientId, clientId));
}
// Update site associations if provided
// Remove sites that are no longer associated
for (const siteId of sitesRemoved) {
await trx
.delete(clientSites)
.where(
and(
eq(clientSites.clientId, clientId),
eq(clientSites.siteId, siteId)
)
);
}
// Add new site associations
for (const siteId of sitesAdded) {
await trx.insert(clientSites).values({
clientId,
siteId
});
}
// Fetch the updated client
[updatedClient] = await trx
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
// get all sites for this client and join with exit nodes with site.exitNodeId
sitesData = await trx
.select()
.from(sites)
.innerJoin(clientSites, eq(sites.siteId, clientSites.siteId))
.leftJoin(exitNodes, eq(sites.exitNodeId, exitNodes.exitNodeId))
.where(eq(clientSites.clientId, client.clientId));
});
logger.info(
`Adding ${sitesAdded.length} new sites to client ${client.clientId}`
);
for (const siteId of sitesAdded) {
if (!client.subnet || !client.pubKey) {
logger.debug("Client subnet, pubKey or endpoint is not set");
continue;
}
// TODO: WE NEED TO HANDLE THIS BETTER. WE ARE DEFAULTING TO RELAYING FOR NEW SITES
// BUT REALLY WE NEED TO TRACK THE USERS PREFERENCE THAT THEY CHOSE IN THE CLIENTS
// AND TRIGGER A HOLEPUNCH OR SOMETHING TO GET THE ENDPOINT AND HP TO THE NEW SITES
const isRelayed = true;
const site = await newtAddPeer(siteId, {
publicKey: client.pubKey,
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
// endpoint: isRelayed ? "" : clientSite.endpoint
endpoint: isRelayed ? "" : "" // we are not HPing yet so no endpoint
});
if (!site) {
logger.debug("Failed to add peer to newt - missing site");
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
let endpoint;
if (isRelayed) {
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} has no exit node, skipping`
);
return null;
}
// get the exit node for the site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
.limit(1);
if (!exitNode) {
logger.warn(`Exit node not found for site ${site.siteId}`);
return null;
}
endpoint = `${exitNode.endpoint}:21820`;
} else {
if (!site.endpoint) {
logger.warn(
`Site ${site.siteId} has no endpoint, skipping`
);
return null;
}
endpoint = site.endpoint;
}
await olmAddPeer(client.clientId, {
siteId: site.siteId,
endpoint: endpoint,
publicKey: site.publicKey,
serverIP: site.address,
serverPort: site.listenPort,
remoteSubnets: site.remoteSubnets
});
}
logger.info(
`Removing ${sitesRemoved.length} sites from client ${client.clientId}`
);
for (const siteId of sitesRemoved) {
if (!client.pubKey) {
logger.debug("Client pubKey is not set");
continue;
}
const site = await newtDeletePeer(siteId, client.pubKey);
if (!site) {
logger.debug("Failed to delete peer from newt - missing site");
continue;
}
if (!site.endpoint || !site.publicKey) {
logger.debug("Site endpoint or publicKey is not set");
continue;
}
await olmDeletePeer(client.clientId, site.siteId, site.publicKey);
}
if (!updatedClient || !sitesData) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Failed to update client`
)
);
}
let exitNodeDestinations: {
reachableAt: string;
exitNodeId: number;
type: string;
name: string;
sourceIp: string;
sourcePort: number;
destinations: PeerDestination[];
}[] = [];
for (const site of sitesData) {
if (!site.sites.subnet) {
logger.warn(
`Site ${site.sites.siteId} has no subnet, skipping`
);
continue;
}
if (!site.clientSites.endpoint) {
logger.warn(
`Site ${site.sites.siteId} has no endpoint, skipping`
);
continue;
}
// find the destinations in the array
let destinations = exitNodeDestinations.find(
(d) => d.reachableAt === site.exitNodes?.reachableAt
);
if (!destinations) {
destinations = {
reachableAt: site.exitNodes?.reachableAt || "",
exitNodeId: site.exitNodes?.exitNodeId || 0,
type: site.exitNodes?.type || "",
name: site.exitNodes?.name || "",
sourceIp: site.clientSites.endpoint.split(":")[0] || "",
sourcePort:
parseInt(site.clientSites.endpoint.split(":")[1]) || 0,
destinations: [
{
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
}
]
};
} else {
// add to the existing destinations
destinations.destinations.push({
destinationIP: site.sites.subnet.split("/")[0],
destinationPort: site.sites.listenPort || 0
});
}
// update it in the array
exitNodeDestinations = exitNodeDestinations.filter(
(d) => d.reachableAt !== site.exitNodes?.reachableAt
);
exitNodeDestinations.push(destinations);
}
for (const destination of exitNodeDestinations) {
logger.info(
`Updating destinations for exit node at ${destination.reachableAt}`
);
const payload = {
sourceIp: destination.sourceIp,
sourcePort: destination.sourcePort,
destinations: destination.destinations
};
logger.info(
`Payload for update-destinations: ${JSON.stringify(payload, null, 2)}`
);
// Create an ExitNode-like object for sendToExitNode
const exitNodeForComm = {
exitNodeId: destination.exitNodeId,
type: destination.type,
reachableAt: destination.reachableAt,
name: destination.name
} as any; // Using 'as any' since we know sendToExitNode will handle this correctly
await sendToExitNode(exitNodeForComm, {
remoteType: "remoteExitNode/update-destinations",
localPath: "/update-destinations",
method: "POST",
data: payload
});
}
return response(res, { return response(res, {
data: updatedClient, data: updatedClient,

View File

@@ -1,4 +1,4 @@
import { db } from "@server/db"; import { db, Site } from "@server/db";
import { newts, sites } from "@server/db"; import { newts, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
@@ -10,65 +10,74 @@ export async function addPeer(
publicKey: string; publicKey: string;
allowedIps: string[]; allowedIps: string[];
endpoint: string; endpoint: string;
} },
newtId?: string
) { ) {
const [site] = await db let site: Site | null = null;
.select() if (!newtId) {
.from(sites) [site] = await db
.where(eq(sites.siteId, siteId)) .select()
.limit(1); .from(sites)
if (!site) { .where(eq(sites.siteId, siteId))
throw new Error(`Exit node with ID ${siteId} not found`); .limit(1);
if (!site) {
throw new Error(`Exit node with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Site found for site ${siteId}`);
}
newtId = newt.newtId;
} }
// get the newt on the site await sendToClient(newtId, {
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Site found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/add", type: "newt/wg/peer/add",
data: peer data: peer
}); });
logger.info(`Added peer ${peer.publicKey} to newt ${newt.newtId}`); logger.info(`Added peer ${peer.publicKey} to newt ${newtId}`);
return site; return site;
} }
export async function deletePeer(siteId: number, publicKey: string) { export async function deletePeer(siteId: number, publicKey: string, newtId?: string) {
const [site] = await db let site: Site | null = null;
.select() if (!newtId) {
.from(sites) [site] = await db
.where(eq(sites.siteId, siteId)) .select()
.limit(1); .from(sites)
if (!site) { .where(eq(sites.siteId, siteId))
throw new Error(`Site with ID ${siteId} not found`); .limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
newtId = newt.newtId;
} }
// get the newt on the site await sendToClient(newtId, {
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/remove", type: "newt/wg/peer/remove",
data: { data: {
publicKey publicKey
} }
}); });
logger.info(`Deleted peer ${publicKey} from newt ${newt.newtId}`); logger.info(`Deleted peer ${publicKey} from newt ${newtId}`);
return site; return site;
} }
@@ -79,28 +88,33 @@ export async function updatePeer(
peer: { peer: {
allowedIps?: string[]; allowedIps?: string[];
endpoint?: string; endpoint?: string;
} },
newtId?: string
) { ) {
const [site] = await db let site: Site | null = null;
.select() if (!newtId) {
.from(sites) [site] = await db
.where(eq(sites.siteId, siteId)) .select()
.limit(1); .from(sites)
if (!site) { .where(eq(sites.siteId, siteId))
throw new Error(`Site with ID ${siteId} not found`); .limit(1);
if (!site) {
throw new Error(`Site with ID ${siteId} not found`);
}
// get the newt on the site
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
newtId = newt.newtId;
} }
// get the newt on the site await sendToClient(newtId, {
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId))
.limit(1);
if (!newt) {
throw new Error(`Newt not found for site ${siteId}`);
}
sendToClient(newt.newtId, {
type: "newt/wg/peer/update", type: "newt/wg/peer/update",
data: { data: {
publicKey, publicKey,
@@ -108,7 +122,7 @@ export async function updatePeer(
} }
}); });
logger.info(`Updated peer ${publicKey} on newt ${newt.newtId}`); logger.info(`Updated peer ${publicKey} on newt ${newtId}`);
return site; return site;
} }

View File

@@ -388,6 +388,11 @@ async function getOrCreateOrgClient(
clientId: newClient.clientId clientId: newClient.clientId
}); });
await trx.insert(userClients).values({ // we also want to make sure that the user can see their own client if they are not an admin
userId,
clientId: newClient.clientId
});
if (userOrg.roleId != adminRole.roleId) { if (userOrg.roleId != adminRole.roleId) {
// make sure the user can access the client // make sure the user can access the client
trx.insert(userClients).values({ trx.insert(userClients).values({

View File

@@ -13,18 +13,22 @@ export async function addPeer(
serverIP: string | null; serverIP: string | null;
serverPort: number | null; serverPort: number | null;
remoteSubnets: string | null; // optional, comma-separated list of subnets that this site can access remoteSubnets: string | null; // optional, comma-separated list of subnets that this site can access
} },
olmId?: string
) { ) {
const [olm] = await db if (!olmId) {
.select() const [olm] = await db
.from(olms) .select()
.where(eq(olms.clientId, clientId)) .from(olms)
.limit(1); .where(eq(olms.clientId, clientId))
if (!olm) { .limit(1);
throw new Error(`Olm with ID ${clientId} not found`); if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
olmId = olm.olmId;
} }
await sendToClient(olm.olmId, { await sendToClient(olmId, {
type: "olm/wg/peer/add", type: "olm/wg/peer/add",
data: { data: {
siteId: peer.siteId, siteId: peer.siteId,
@@ -36,20 +40,28 @@ export async function addPeer(
} }
}); });
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`); logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
} }
export async function deletePeer(clientId: number, siteId: number, publicKey: string) { export async function deletePeer(
const [olm] = await db clientId: number,
.select() siteId: number,
.from(olms) publicKey: string,
.where(eq(olms.clientId, clientId)) olmId?: string
.limit(1); ) {
if (!olm) { if (!olmId) {
throw new Error(`Olm with ID ${clientId} not found`); const [olm] = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
olmId = olm.olmId;
} }
await sendToClient(olm.olmId, { await sendToClient(olmId, {
type: "olm/wg/peer/remove", type: "olm/wg/peer/remove",
data: { data: {
publicKey, publicKey,
@@ -57,7 +69,7 @@ export async function deletePeer(clientId: number, siteId: number, publicKey: st
} }
}); });
logger.info(`Deleted peer ${publicKey} from olm ${olm.olmId}`); logger.info(`Deleted peer ${publicKey} from olm ${olmId}`);
} }
export async function updatePeer( export async function updatePeer(
@@ -69,18 +81,22 @@ export async function updatePeer(
serverIP: string | null; serverIP: string | null;
serverPort: number | null; serverPort: number | null;
remoteSubnets?: string | null; // optional, comma-separated list of subnets that remoteSubnets?: string | null; // optional, comma-separated list of subnets that
} },
olmId?: string
) { ) {
const [olm] = await db if (!olmId) {
.select() const [olm] = await db
.from(olms) .select()
.where(eq(olms.clientId, clientId)) .from(olms)
.limit(1); .where(eq(olms.clientId, clientId))
if (!olm) { .limit(1);
throw new Error(`Olm with ID ${clientId} not found`); if (!olm) {
throw new Error(`Olm with ID ${clientId} not found`);
}
olmId = olm.olmId;
} }
await sendToClient(olm.olmId, { await sendToClient(olmId, {
type: "olm/wg/peer/update", type: "olm/wg/peer/update",
data: { data: {
siteId: peer.siteId, siteId: peer.siteId,
@@ -92,5 +108,5 @@ export async function updatePeer(
} }
}); });
logger.info(`Added peer ${peer.publicKey} to olm ${olm.olmId}`); logger.info(`Added peer ${peer.publicKey} to olm ${olmId}`);
} }

View File

@@ -11,6 +11,7 @@ import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { addTargets } from "../client/targets"; import { addTargets } from "../client/targets";
import { getUniqueSiteResourceName } from "@server/db/names"; import { getUniqueSiteResourceName } from "@server/db/names";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
const createSiteResourceParamsSchema = z const createSiteResourceParamsSchema = z
.object({ .object({
@@ -29,7 +30,8 @@ const createSiteResourceSchema = z
destination: z.string().min(1), destination: z.string().min(1),
enabled: z.boolean().default(true), enabled: z.boolean().default(true),
alias: z.string().optional() alias: z.string().optional()
}).strict() })
.strict()
.refine( .refine(
(data) => { (data) => {
if (data.mode === "port") { if (data.mode === "port") {
@@ -145,61 +147,75 @@ export async function createSiteResource(
const niceId = await getUniqueSiteResourceName(orgId); const niceId = await getUniqueSiteResourceName(orgId);
// Create the site resource let newSiteResource: SiteResource | undefined;
const [newSiteResource] = await db await db.transaction(async (trx) => {
.insert(siteResources) // Create the site resource
.values({ [newSiteResource] = await trx
siteId, .insert(siteResources)
niceId, .values({
orgId, siteId,
name, niceId,
mode, orgId,
protocol: mode === "port" ? protocol : null, name,
proxyPort: mode === "port" ? proxyPort : null, mode,
destinationPort: mode === "port" ? destinationPort : null, protocol: mode === "port" ? protocol : null,
destination, proxyPort: mode === "port" ? proxyPort : null,
enabled, destinationPort: mode === "port" ? destinationPort : null,
alias: alias || null destination,
}) enabled,
.returning(); alias: alias || null
})
.returning();
const adminRole = await db const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (adminRole.length === 0) {
return next(
createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
);
}
await db.insert(roleSiteResources).values({
roleId: adminRole[0].roleId,
siteResourceId: newSiteResource.siteResourceId
});
// Only add targets for port mode
if (mode === "port" && protocol && proxyPort && destinationPort) {
const [newt] = await db
.select() .select()
.from(newts) .from(roles)
.where(eq(newts.siteId, site.siteId)) .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1); .limit(1);
if (!newt) { if (!adminRole) {
return next( return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found") createHttpError(HttpCode.NOT_FOUND, `Admin role not found`)
); );
} }
await addTargets( await trx.insert(roleSiteResources).values({
newt.newtId, roleId: adminRole.roleId,
destination, siteResourceId: newSiteResource.siteResourceId
destinationPort, });
protocol,
proxyPort // Only add targets for port mode
if (mode === "port" && protocol && proxyPort && destinationPort) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
await addTargets(
newt.newtId,
destination,
destinationPort,
protocol,
proxyPort
);
}
await rebuildSiteClientAssociations(newSiteResource, trx); // we need to call this because we added to the admin role
});
if (!newSiteResource) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site resource creation failed"
)
); );
} }

View File

@@ -10,10 +10,14 @@ import { fromError } from "zod-validation-error";
import logger from "@server/logger"; import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { removeTargets } from "../client/targets"; import { removeTargets } from "../client/targets";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
const deleteSiteResourceParamsSchema = z const deleteSiteResourceParamsSchema = z
.object({ .object({
siteResourceId: z.string().transform(Number).pipe(z.number().int().positive()), siteResourceId: z
.string()
.transform(Number)
.pipe(z.number().int().positive()),
siteId: z.string().transform(Number).pipe(z.number().int().positive()), siteId: z.string().transform(Number).pipe(z.number().int().positive()),
orgId: z.string() orgId: z.string()
}) })
@@ -40,7 +44,9 @@ export async function deleteSiteResource(
next: NextFunction next: NextFunction
): Promise<any> { ): Promise<any> {
try { try {
const parsedParams = deleteSiteResourceParamsSchema.safeParse(req.params); const parsedParams = deleteSiteResourceParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) { if (!parsedParams.success) {
return next( return next(
createHttpError( createHttpError(
@@ -66,53 +72,61 @@ export async function deleteSiteResource(
const [existingSiteResource] = await db const [existingSiteResource] = await db
.select() .select()
.from(siteResources) .from(siteResources)
.where(and( .where(and(eq(siteResources.siteResourceId, siteResourceId)))
eq(siteResources.siteResourceId, siteResourceId),
eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
))
.limit(1); .limit(1);
if (!existingSiteResource) { if (!existingSiteResource) {
return next( return next(
createHttpError( createHttpError(HttpCode.NOT_FOUND, "Site resource not found")
HttpCode.NOT_FOUND,
"Site resource not found"
)
); );
} }
// Delete the site resource await db.transaction(async (trx) => {
await db // Delete the site resource
.delete(siteResources) await trx
.where(and( .delete(siteResources)
eq(siteResources.siteResourceId, siteResourceId), .where(
eq(siteResources.siteId, siteId), and(
eq(siteResources.orgId, orgId) eq(siteResources.siteResourceId, siteResourceId),
)); eq(siteResources.siteId, siteId),
eq(siteResources.orgId, orgId)
)
);
// Only remove targets for port mode // Only remove targets for port mode
if (existingSiteResource.mode === "port" && existingSiteResource.protocol && existingSiteResource.proxyPort && existingSiteResource.destinationPort) { if (
const [newt] = await db existingSiteResource.mode === "port" &&
.select() existingSiteResource.protocol &&
.from(newts) existingSiteResource.proxyPort &&
.where(eq(newts.siteId, site.siteId)) existingSiteResource.destinationPort
.limit(1); ) {
const [newt] = await trx
.select()
.from(newts)
.where(eq(newts.siteId, site.siteId))
.limit(1);
if (!newt) { if (!newt) {
return next(createHttpError(HttpCode.NOT_FOUND, "Newt not found")); return next(
createHttpError(HttpCode.NOT_FOUND, "Newt not found")
);
}
await removeTargets(
newt.newtId,
existingSiteResource.destination,
existingSiteResource.destinationPort,
existingSiteResource.protocol,
existingSiteResource.proxyPort
);
} }
await removeTargets( await rebuildSiteClientAssociations(existingSiteResource, trx);
newt.newtId, });
existingSiteResource.destination,
existingSiteResource.destinationPort,
existingSiteResource.protocol,
existingSiteResource.proxyPort
);
}
logger.info(`Deleted site resource ${siteResourceId} for site ${siteId}`); logger.info(
`Deleted site resource ${siteResourceId} for site ${siteId}`
);
return response(res, { return response(res, {
data: { message: "Site resource deleted successfully" }, data: { message: "Site resource deleted successfully" },
@@ -123,6 +137,11 @@ export async function deleteSiteResource(
}); });
} catch (error) { } catch (error) {
logger.error("Error deleting site resource:", error); logger.error("Error deleting site resource:", error);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to delete site resource")); return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to delete site resource"
)
);
} }
} }

View File

@@ -9,6 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq, and, ne } from "drizzle-orm"; import { eq, and, ne } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
const setSiteResourceRolesBodySchema = z const setSiteResourceRolesBodySchema = z
.object({ .object({
@@ -62,7 +63,9 @@ export async function setSiteResourceRoles(
const { roleIds } = parsedBody.data; const { roleIds } = parsedBody.data;
const parsedParams = setSiteResourceRolesParamsSchema.safeParse(req.params); const parsedParams = setSiteResourceRolesParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) { if (!parsedParams.success) {
return next( return next(
createHttpError( createHttpError(
@@ -136,6 +139,8 @@ export async function setSiteResourceRoles(
.returning() .returning()
) )
); );
await rebuildSiteClientAssociations(siteResource, trx);
}); });
return response(res, { return response(res, {
@@ -152,4 +157,3 @@ export async function setSiteResourceRoles(
); );
} }
} }

View File

@@ -1,6 +1,6 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { z } from "zod"; import { z } from "zod";
import { db } from "@server/db"; import { db, siteResources } from "@server/db";
import { userSiteResources } from "@server/db"; import { userSiteResources } from "@server/db";
import response from "@server/lib/response"; import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode"; import HttpCode from "@server/types/HttpCode";
@@ -9,6 +9,7 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi"; import { OpenAPITags, registry } from "@server/openApi";
import { rebuildSiteClientAssociations } from "@server/lib/rebuildSiteClientAssociations";
const setSiteResourceUsersBodySchema = z const setSiteResourceUsersBodySchema = z
.object({ .object({
@@ -74,6 +75,22 @@ export async function setSiteResourceUsers(
const { siteResourceId } = parsedParams.data; const { siteResourceId } = parsedParams.data;
// get the site resource
const [siteResource] = await db
.select()
.from(siteResources)
.where(eq(siteResources.siteResourceId, siteResourceId))
.limit(1);
if (!siteResource) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Site resource not found"
)
);
}
await db.transaction(async (trx) => { await db.transaction(async (trx) => {
await trx await trx
.delete(userSiteResources) .delete(userSiteResources)
@@ -87,6 +104,8 @@ export async function setSiteResourceUsers(
.returning() .returning()
) )
); );
await rebuildSiteClientAssociations(siteResource, trx);
}); });
return response(res, { return response(res, {