mirror of
https://github.com/fosrl/pangolin.git
synced 2026-03-02 16:56:39 +00:00
Handle delete org and checking org policy
This commit is contained in:
@@ -5,6 +5,8 @@ import { clients, Olm } from "@server/db";
|
|||||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||||
|
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
|
||||||
|
import { sendTerminateClient } from "../client/terminate";
|
||||||
|
|
||||||
// Track if the offline checker interval is running
|
// Track if the offline checker interval is running
|
||||||
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||||
@@ -57,6 +59,9 @@ export const startOlmOfflineChecker = (): void => {
|
|||||||
|
|
||||||
// Send a disconnect message to the client if connected
|
// Send a disconnect message to the client if connected
|
||||||
try {
|
try {
|
||||||
|
await sendTerminateClient(offlineClient.clientId); // terminate first
|
||||||
|
// wait a moment to ensure the message is sent
|
||||||
|
await new Promise(resolve => setTimeout(resolve, 1000));
|
||||||
await disconnectClient(offlineClient.olmId);
|
await disconnectClient(offlineClient.olmId);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error(
|
logger.error(
|
||||||
@@ -110,6 +115,36 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
|||||||
logger.warn("User ID mismatch for olm ping");
|
logger.warn("User ID mismatch for olm ping");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// get the client
|
||||||
|
const [client] = await db
|
||||||
|
.select()
|
||||||
|
.from(clients)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(clients.olmId, olm.olmId),
|
||||||
|
eq(clients.userId, olm.userId)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.limit(1);
|
||||||
|
|
||||||
|
if (!client) {
|
||||||
|
logger.warn("Client not found for olm ping");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const policyCheck = await checkOrgAccessPolicy({
|
||||||
|
orgId: client.orgId,
|
||||||
|
userId: olm.userId,
|
||||||
|
session: userToken // this is the user token passed in the message
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!policyCheck.allowed) {
|
||||||
|
logger.warn(
|
||||||
|
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!olm.clientId) {
|
if (!olm.clientId) {
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ import {
|
|||||||
} from "@server/lib/ip";
|
} from "@server/lib/ip";
|
||||||
import { generateRemoteSubnets } from "@server/lib/ip";
|
import { generateRemoteSubnets } from "@server/lib/ip";
|
||||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||||
|
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
|
||||||
|
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||||
|
|
||||||
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||||
logger.info("Handling register olm message!");
|
logger.info("Handling register olm message!");
|
||||||
@@ -45,7 +47,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient } =
|
const { publicKey, relay, olmVersion, orgId, doNotCreateNewClient, token: userToken } =
|
||||||
message.data;
|
message.data;
|
||||||
|
|
||||||
let client: Client | undefined;
|
let client: Client | undefined;
|
||||||
@@ -78,6 +80,35 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!olm.userId) {
|
||||||
|
logger.warn("Olm has no user ID");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { session: userSession, user } =
|
||||||
|
await validateSessionToken(userToken);
|
||||||
|
if (!userSession || !user) {
|
||||||
|
logger.warn("Invalid user session for olm ping");
|
||||||
|
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
|
||||||
|
}
|
||||||
|
if (user.userId !== olm.userId) {
|
||||||
|
logger.warn("User ID mismatch for olm ping");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const policyCheck = await checkOrgAccessPolicy({
|
||||||
|
orgId: orgId,
|
||||||
|
userId: olm.userId,
|
||||||
|
session: userToken // this is the user token passed in the message
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!policyCheck.allowed) {
|
||||||
|
logger.warn(
|
||||||
|
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
|
`Switching olm client ${olm.olmId} to org ${orgId} for user ${olm.userId}`
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,6 +1,15 @@
|
|||||||
import { Request, Response, NextFunction } from "express";
|
import { Request, Response, NextFunction } from "express";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { db, domains, orgDomains, resources } from "@server/db";
|
import {
|
||||||
|
clients,
|
||||||
|
clientSiteResourcesAssociationsCache,
|
||||||
|
clientSitesAssociationsCache,
|
||||||
|
db,
|
||||||
|
domains,
|
||||||
|
olms,
|
||||||
|
orgDomains,
|
||||||
|
resources
|
||||||
|
} from "@server/db";
|
||||||
import { newts, newtSessions, orgs, sites, userActions } from "@server/db";
|
import { newts, newtSessions, orgs, sites, userActions } from "@server/db";
|
||||||
import { eq, and, inArray, sql } from "drizzle-orm";
|
import { eq, and, inArray, sql } from "drizzle-orm";
|
||||||
import response from "@server/lib/response";
|
import response from "@server/lib/response";
|
||||||
@@ -14,8 +23,8 @@ import { deletePeer } from "../gerbil/peers";
|
|||||||
import { OpenAPITags, registry } from "@server/openApi";
|
import { OpenAPITags, registry } from "@server/openApi";
|
||||||
|
|
||||||
const deleteOrgSchema = z.strictObject({
|
const deleteOrgSchema = z.strictObject({
|
||||||
orgId: z.string()
|
orgId: z.string()
|
||||||
});
|
});
|
||||||
|
|
||||||
export type DeleteOrgResponse = {};
|
export type DeleteOrgResponse = {};
|
||||||
|
|
||||||
@@ -69,41 +78,75 @@ export async function deleteOrg(
|
|||||||
.where(eq(sites.orgId, orgId))
|
.where(eq(sites.orgId, orgId))
|
||||||
.limit(1);
|
.limit(1);
|
||||||
|
|
||||||
|
const orgClients = await db
|
||||||
|
.select()
|
||||||
|
.from(clients)
|
||||||
|
.where(eq(clients.orgId, orgId));
|
||||||
|
|
||||||
const deletedNewtIds: string[] = [];
|
const deletedNewtIds: string[] = [];
|
||||||
|
const olmsToTerminate: string[] = [];
|
||||||
|
|
||||||
await db.transaction(async (trx) => {
|
await db.transaction(async (trx) => {
|
||||||
if (sites) {
|
for (const site of orgSites) {
|
||||||
for (const site of orgSites) {
|
if (site.pubKey) {
|
||||||
if (site.pubKey) {
|
if (site.type == "wireguard") {
|
||||||
if (site.type == "wireguard") {
|
await deletePeer(site.exitNodeId!, site.pubKey);
|
||||||
await deletePeer(site.exitNodeId!, site.pubKey);
|
} else if (site.type == "newt") {
|
||||||
} else if (site.type == "newt") {
|
// get the newt on the site by querying the newt table for siteId
|
||||||
// get the newt on the site by querying the newt table for siteId
|
const [deletedNewt] = await trx
|
||||||
const [deletedNewt] = await trx
|
.delete(newts)
|
||||||
.delete(newts)
|
.where(eq(newts.siteId, site.siteId))
|
||||||
.where(eq(newts.siteId, site.siteId))
|
.returning();
|
||||||
.returning();
|
if (deletedNewt) {
|
||||||
if (deletedNewt) {
|
deletedNewtIds.push(deletedNewt.newtId);
|
||||||
deletedNewtIds.push(deletedNewt.newtId);
|
|
||||||
|
|
||||||
// delete all of the sessions for the newt
|
// delete all of the sessions for the newt
|
||||||
await trx
|
await trx
|
||||||
.delete(newtSessions)
|
.delete(newtSessions)
|
||||||
.where(
|
.where(
|
||||||
eq(
|
eq(newtSessions.newtId, deletedNewt.newtId)
|
||||||
newtSessions.newtId,
|
);
|
||||||
deletedNewt.newtId
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(`Deleting site ${site.siteId}`);
|
|
||||||
await trx
|
|
||||||
.delete(sites)
|
|
||||||
.where(eq(sites.siteId, site.siteId));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(`Deleting site ${site.siteId}`);
|
||||||
|
await trx.delete(sites).where(eq(sites.siteId, site.siteId));
|
||||||
|
}
|
||||||
|
for (const client of orgClients) {
|
||||||
|
const [olm] = await trx
|
||||||
|
.select()
|
||||||
|
.from(olms)
|
||||||
|
.where(eq(olms.clientId, client.clientId))
|
||||||
|
.limit(1);
|
||||||
|
|
||||||
|
if (olm) {
|
||||||
|
olmsToTerminate.push(olm.olmId);
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(`Deleting client ${client.clientId}`);
|
||||||
|
await trx
|
||||||
|
.delete(clients)
|
||||||
|
.where(eq(clients.clientId, client.clientId));
|
||||||
|
|
||||||
|
// also delete the associations
|
||||||
|
await trx
|
||||||
|
.delete(clientSiteResourcesAssociationsCache)
|
||||||
|
.where(
|
||||||
|
eq(
|
||||||
|
clientSiteResourcesAssociationsCache.clientId,
|
||||||
|
client.clientId
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
await trx
|
||||||
|
.delete(clientSitesAssociationsCache)
|
||||||
|
.where(
|
||||||
|
eq(
|
||||||
|
clientSitesAssociationsCache.clientId,
|
||||||
|
client.clientId
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const allOrgDomains = await trx
|
const allOrgDomains = await trx
|
||||||
@@ -162,6 +205,18 @@ export async function deleteOrg(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (const olmId of olmsToTerminate) {
|
||||||
|
sendToClient(olmId, {
|
||||||
|
type: "olm/terminate",
|
||||||
|
data: {}
|
||||||
|
}).catch((error) => {
|
||||||
|
logger.error(
|
||||||
|
"Failed to send termination message to olm:",
|
||||||
|
error
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return response(res, {
|
return response(res, {
|
||||||
data: null,
|
data: null,
|
||||||
success: true,
|
success: true,
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ export async function addUserRole(
|
|||||||
.returning();
|
.returning();
|
||||||
|
|
||||||
// get the client associated with this user in this org
|
// get the client associated with this user in this org
|
||||||
const [orgClient] = await trx
|
const orgClients = await trx
|
||||||
.select()
|
.select()
|
||||||
.from(clients)
|
.from(clients)
|
||||||
.where(
|
.where(
|
||||||
@@ -136,7 +136,7 @@ export async function addUserRole(
|
|||||||
)
|
)
|
||||||
.limit(1);
|
.limit(1);
|
||||||
|
|
||||||
if (orgClient) {
|
for (const orgClient of orgClients) {
|
||||||
// we just changed the user's role, so we need to rebuild client associations and what they have access to
|
// we just changed the user's role, so we need to rebuild client associations and what they have access to
|
||||||
await rebuildClientAssociationsFromClient(orgClient, trx);
|
await rebuildClientAssociationsFromClient(orgClient, trx);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user