mirror of
https://github.com/fosrl/pangolin.git
synced 2026-02-26 23:06:37 +00:00
Merge branch 'dev' into feat/login-page-customization
This commit is contained in:
@@ -19,8 +19,14 @@ import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
|
||||
import {
|
||||
queryAccessAuditLogsParams,
|
||||
queryAccessAuditLogsQuery,
|
||||
queryAccess,
|
||||
countAccessQuery
|
||||
} from "./queryAccessAuditLog";
|
||||
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
|
||||
import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs";
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
@@ -61,16 +67,28 @@ export async function exportAccessAuditLogs(
|
||||
}
|
||||
|
||||
const data = { ...parsedQuery.data, ...parsedParams.data };
|
||||
const [{ count }] = await countAccessQuery(data);
|
||||
if (count > MAX_EXPORT_LIMIT) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const baseQuery = queryAccess(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
const csvData = generateCSV(log);
|
||||
|
||||
res.setHeader('Content-Type', 'text/csv');
|
||||
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
|
||||
|
||||
|
||||
res.setHeader("Content-Type", "text/csv");
|
||||
res.setHeader(
|
||||
"Content-Disposition",
|
||||
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
|
||||
);
|
||||
|
||||
return res.send(csvData);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
@@ -78,4 +96,4 @@ export async function exportAccessAuditLogs(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,14 @@ import createHttpError from "http-errors";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
|
||||
import {
|
||||
queryActionAuditLogsParams,
|
||||
queryActionAuditLogsQuery,
|
||||
queryAction,
|
||||
countActionQuery
|
||||
} from "./queryActionAuditLog";
|
||||
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
|
||||
import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs";
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
@@ -60,17 +66,29 @@ export async function exportActionAuditLogs(
|
||||
);
|
||||
}
|
||||
|
||||
const data = { ...parsedQuery.data, ...parsedParams.data };
|
||||
const data = { ...parsedQuery.data, ...parsedParams.data };
|
||||
const [{ count }] = await countActionQuery(data);
|
||||
if (count > MAX_EXPORT_LIMIT) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const baseQuery = queryAction(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
const csvData = generateCSV(log);
|
||||
|
||||
res.setHeader('Content-Type', 'text/csv');
|
||||
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
|
||||
|
||||
|
||||
res.setHeader("Content-Type", "text/csv");
|
||||
res.setHeader(
|
||||
"Content-Disposition",
|
||||
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
|
||||
);
|
||||
|
||||
return res.send(csvData);
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
@@ -78,4 +96,4 @@ export async function exportActionAuditLogs(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,4 +14,4 @@
|
||||
export * from "./queryActionAuditLog";
|
||||
export * from "./exportActionAuditLog";
|
||||
export * from "./queryAccessAuditLog";
|
||||
export * from "./exportAccessAuditLog";
|
||||
export * from "./exportAccessAuditLog";
|
||||
|
||||
@@ -24,6 +24,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { QueryAccessAuditLogResponse } from "@server/routers/auditLogs/types";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
|
||||
|
||||
export const queryAccessAuditLogsQuery = z.object({
|
||||
// iso string just validate its a parseable date
|
||||
@@ -32,7 +33,14 @@ export const queryAccessAuditLogsQuery = z.object({
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
error: "timeStart must be a valid ISO date string"
|
||||
})
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000)),
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
|
||||
.prefault(() => getSevenDaysAgo().toISOString())
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description:
|
||||
"Start time as ISO date string (defaults to 7 days ago)"
|
||||
}),
|
||||
timeEnd: z
|
||||
.string()
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
@@ -44,7 +52,8 @@ export const queryAccessAuditLogsQuery = z.object({
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description: "End time as ISO date string (defaults to current time)"
|
||||
description:
|
||||
"End time as ISO date string (defaults to current time)"
|
||||
}),
|
||||
action: z
|
||||
.union([z.boolean(), z.string()])
|
||||
@@ -181,9 +190,15 @@ async function queryUniqueFilterAttributes(
|
||||
.where(baseConditions);
|
||||
|
||||
return {
|
||||
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter((row): row is { id: number; name: string | null } => row.id !== null),
|
||||
locations: uniqueLocations.map(row => row.locations).filter((location): location is string => location !== null)
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter(
|
||||
(row): row is { id: number; name: string | null } => row.id !== null
|
||||
),
|
||||
locations: uniqueLocations
|
||||
.map((row) => row.locations)
|
||||
.filter((location): location is string => location !== null)
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { QueryActionAuditLogResponse } from "@server/routers/auditLogs/types";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
|
||||
|
||||
export const queryActionAuditLogsQuery = z.object({
|
||||
// iso string just validate its a parseable date
|
||||
@@ -32,7 +33,14 @@ export const queryActionAuditLogsQuery = z.object({
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
error: "timeStart must be a valid ISO date string"
|
||||
})
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000)),
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
|
||||
.prefault(() => getSevenDaysAgo().toISOString())
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description:
|
||||
"Start time as ISO date string (defaults to 7 days ago)"
|
||||
}),
|
||||
timeEnd: z
|
||||
.string()
|
||||
.refine((val) => !isNaN(Date.parse(val)), {
|
||||
@@ -44,7 +52,8 @@ export const queryActionAuditLogsQuery = z.object({
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
description: "End time as ISO date string (defaults to current time)"
|
||||
description:
|
||||
"End time as ISO date string (defaults to current time)"
|
||||
}),
|
||||
action: z.string().optional(),
|
||||
actorType: z.string().optional(),
|
||||
@@ -68,8 +77,9 @@ export const queryActionAuditLogsParams = z.object({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export const queryActionAuditLogsCombined =
|
||||
queryActionAuditLogsQuery.merge(queryActionAuditLogsParams);
|
||||
export const queryActionAuditLogsCombined = queryActionAuditLogsQuery.merge(
|
||||
queryActionAuditLogsParams
|
||||
);
|
||||
type Q = z.infer<typeof queryActionAuditLogsCombined>;
|
||||
|
||||
function getWhere(data: Q) {
|
||||
@@ -78,7 +88,9 @@ function getWhere(data: Q) {
|
||||
lt(actionAuditLog.timestamp, data.timeEnd),
|
||||
eq(actionAuditLog.orgId, data.orgId),
|
||||
data.actor ? eq(actionAuditLog.actor, data.actor) : undefined,
|
||||
data.actorType ? eq(actionAuditLog.actorType, data.actorType) : undefined,
|
||||
data.actorType
|
||||
? eq(actionAuditLog.actorType, data.actorType)
|
||||
: undefined,
|
||||
data.actorId ? eq(actionAuditLog.actorId, data.actorId) : undefined,
|
||||
data.action ? eq(actionAuditLog.action, data.action) : undefined
|
||||
);
|
||||
@@ -135,8 +147,12 @@ async function queryUniqueFilterAttributes(
|
||||
.where(baseConditions);
|
||||
|
||||
return {
|
||||
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
|
||||
actions: uniqueActions.map(row => row.action).filter((action): action is string => action !== null),
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
actions: uniqueActions
|
||||
.map((row) => row.action)
|
||||
.filter((action): action is string => action !== null)
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -13,4 +13,4 @@
|
||||
|
||||
export * from "./transferSession";
|
||||
export * from "./getSessionTransferToken";
|
||||
export * from "./quickStart";
|
||||
export * from "./quickStart";
|
||||
|
||||
@@ -395,7 +395,8 @@ export async function quickStart(
|
||||
.values({
|
||||
targetId: newTarget[0].targetId,
|
||||
hcEnabled: false
|
||||
}).returning();
|
||||
})
|
||||
.returning();
|
||||
|
||||
// add the new target to the targetIps array
|
||||
targetIps.push(`${ip}/32`);
|
||||
@@ -406,7 +407,12 @@ export async function quickStart(
|
||||
.where(eq(newts.siteId, siteId!))
|
||||
.limit(1);
|
||||
|
||||
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
|
||||
await addTargets(
|
||||
newt.newtId,
|
||||
newTarget,
|
||||
newHealthcheck,
|
||||
resource.protocol
|
||||
);
|
||||
|
||||
// Set resource pincode if provided
|
||||
if (pincode) {
|
||||
|
||||
@@ -26,8 +26,8 @@ import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
|
||||
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
|
||||
|
||||
const createCheckoutSessionSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function createCheckoutSession(
|
||||
req: Request,
|
||||
@@ -72,7 +72,7 @@ export async function createCheckoutSession(
|
||||
billing_address_collection: "required",
|
||||
line_items: [
|
||||
{
|
||||
price: standardTierPrice, // Use the standard tier
|
||||
price: standardTierPrice, // Use the standard tier
|
||||
quantity: 1
|
||||
},
|
||||
...getLineItems(getStandardFeaturePriceSet())
|
||||
|
||||
@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
|
||||
import stripe from "#private/lib/stripe";
|
||||
|
||||
const createPortalSessionSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function createPortalSession(
|
||||
req: Request,
|
||||
|
||||
@@ -34,8 +34,8 @@ import {
|
||||
} from "@server/db";
|
||||
|
||||
const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
|
||||
@@ -28,8 +28,8 @@ import { FeatureId } from "@server/lib/billing";
|
||||
import { GetOrgUsageResponse } from "@server/routers/billing/types";
|
||||
|
||||
const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
@@ -78,11 +78,23 @@ export async function getOrgUsage(
|
||||
// Get usage for org
|
||||
const usageData = [];
|
||||
|
||||
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
|
||||
const siteUptime = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.SITE_UPTIME
|
||||
);
|
||||
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
|
||||
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
|
||||
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
|
||||
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
|
||||
const domains = await usageService.getUsageDaily(
|
||||
orgId,
|
||||
FeatureId.DOMAINS
|
||||
);
|
||||
const remoteExitNodes = await usageService.getUsageDaily(
|
||||
orgId,
|
||||
FeatureId.REMOTE_EXIT_NODES
|
||||
);
|
||||
const egressData = await usageService.getUsage(
|
||||
orgId,
|
||||
FeatureId.EGRESS_DATA_MB
|
||||
);
|
||||
|
||||
if (siteUptime) {
|
||||
usageData.push(siteUptime);
|
||||
@@ -100,7 +112,8 @@ export async function getOrgUsage(
|
||||
usageData.push(remoteExitNodes);
|
||||
}
|
||||
|
||||
const orgLimits = await db.select()
|
||||
const orgLimits = await db
|
||||
.select()
|
||||
.from(limits)
|
||||
.where(eq(limits.orgId, orgId));
|
||||
|
||||
|
||||
@@ -31,9 +31,7 @@ export async function handleCustomerDeleted(
|
||||
return;
|
||||
}
|
||||
|
||||
await db
|
||||
.delete(customers)
|
||||
.where(eq(customers.customerId, customer.id));
|
||||
await db.delete(customers).where(eq(customers.customerId, customer.id));
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error handling customer created event for ID ${customer.id}:`,
|
||||
|
||||
@@ -12,7 +12,14 @@
|
||||
*/
|
||||
|
||||
import Stripe from "stripe";
|
||||
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
|
||||
import {
|
||||
subscriptions,
|
||||
db,
|
||||
subscriptionItems,
|
||||
customers,
|
||||
userOrgs,
|
||||
users
|
||||
} from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
@@ -43,7 +50,6 @@ export async function handleSubscriptionDeleted(
|
||||
.delete(subscriptionItems)
|
||||
.where(eq(subscriptionItems.subscriptionId, subscription.id));
|
||||
|
||||
|
||||
// Lookup customer to get orgId
|
||||
const [customer] = await db
|
||||
.select()
|
||||
@@ -58,10 +64,7 @@ export async function handleSubscriptionDeleted(
|
||||
return;
|
||||
}
|
||||
|
||||
await handleSubscriptionLifesycle(
|
||||
customer.orgId,
|
||||
subscription.status
|
||||
);
|
||||
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
|
||||
|
||||
const [orgUserRes] = await db
|
||||
.select()
|
||||
|
||||
@@ -15,4 +15,4 @@ export * from "./createCheckoutSession";
|
||||
export * from "./createPortalSession";
|
||||
export * from "./getOrgSubscription";
|
||||
export * from "./getOrgUsage";
|
||||
export * from "./internalGetOrgTier";
|
||||
export * from "./internalGetOrgTier";
|
||||
|
||||
@@ -22,8 +22,8 @@ import { getOrgTierData } from "#private/lib/billing";
|
||||
import { GetOrgTierResponse } from "@server/routers/billing/types";
|
||||
|
||||
const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function getOrgTier(
|
||||
req: Request,
|
||||
|
||||
@@ -11,11 +11,18 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
|
||||
import {
|
||||
freeLimitSet,
|
||||
limitsService,
|
||||
subscribedLimitSet
|
||||
} from "@server/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import logger from "@server/logger";
|
||||
|
||||
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
|
||||
export async function handleSubscriptionLifesycle(
|
||||
orgId: string,
|
||||
status: string
|
||||
) {
|
||||
switch (status) {
|
||||
case "active":
|
||||
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
|
||||
@@ -42,4 +49,4 @@ export async function handleSubscriptionLifesycle(orgId: string, status: string)
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,12 +32,13 @@ export async function billingWebhookHandler(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
let event: Stripe.Event = req.body;
|
||||
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
|
||||
const endpointSecret =
|
||||
privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
|
||||
if (!endpointSecret) {
|
||||
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
|
||||
logger.warn(
|
||||
"Stripe webhook secret is not configured. Webhook events will not be priocessed."
|
||||
);
|
||||
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, ""));
|
||||
}
|
||||
|
||||
// Only verify the event if you have an endpoint secret defined.
|
||||
@@ -49,7 +50,10 @@ export async function billingWebhookHandler(
|
||||
if (!signature) {
|
||||
logger.info("No stripe signature found in headers.");
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"No stripe signature found in headers"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -62,7 +66,10 @@ export async function billingWebhookHandler(
|
||||
} catch (err) {
|
||||
logger.error(`Webhook signature verification failed.`, err);
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Webhook signature verification failed"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,10 +24,10 @@ import { registry } from "@server/openApi";
|
||||
import { GetCertificateResponse } from "@server/routers/certificates/types";
|
||||
|
||||
const getCertificateSchema = z.strictObject({
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
});
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
async function query(domainId: string, domain: string) {
|
||||
const [domainRecord] = await db
|
||||
@@ -42,8 +42,8 @@ async function query(domainId: string, domain: string) {
|
||||
|
||||
let existing: any[] = [];
|
||||
if (domainRecord.type == "ns") {
|
||||
const domainLevelDown = domain.split('.').slice(1).join('.');
|
||||
|
||||
const domainLevelDown = domain.split(".").slice(1).join(".");
|
||||
|
||||
existing = await db
|
||||
.select({
|
||||
certId: certificates.certId,
|
||||
@@ -64,7 +64,7 @@ async function query(domainId: string, domain: string) {
|
||||
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
|
||||
or(
|
||||
eq(certificates.domain, domain),
|
||||
eq(certificates.domain, domainLevelDown),
|
||||
eq(certificates.domain, domainLevelDown)
|
||||
)
|
||||
)
|
||||
);
|
||||
@@ -102,8 +102,7 @@ registry.registerPath({
|
||||
tags: ["Certificate"],
|
||||
request: {
|
||||
params: z.object({
|
||||
domainId: z
|
||||
.string(),
|
||||
domainId: z.string(),
|
||||
domain: z.string().min(1).max(255),
|
||||
orgId: z.string()
|
||||
})
|
||||
@@ -133,7 +132,9 @@ export async function getCertificate(
|
||||
|
||||
if (!cert) {
|
||||
logger.warn(`Certificate not found for domain: ${domainId}`);
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
|
||||
);
|
||||
}
|
||||
|
||||
return response<GetCertificateResponse>(res, {
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
*/
|
||||
|
||||
export * from "./getCertificate";
|
||||
export * from "./restartCertificate";
|
||||
export * from "./restartCertificate";
|
||||
|
||||
@@ -25,9 +25,9 @@ import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const restartCertificateParamsSchema = z.strictObject({
|
||||
certId: z.string().transform(stoi).pipe(z.int().positive()),
|
||||
orgId: z.string()
|
||||
});
|
||||
certId: z.string().transform(stoi).pipe(z.int().positive()),
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
@@ -36,10 +36,7 @@ registry.registerPath({
|
||||
tags: ["Certificate"],
|
||||
request: {
|
||||
params: z.object({
|
||||
certId: z
|
||||
.string()
|
||||
.transform(stoi)
|
||||
.pipe(z.int().positive()),
|
||||
certId: z.string().transform(stoi).pipe(z.int().positive()),
|
||||
orgId: z.string()
|
||||
})
|
||||
},
|
||||
@@ -94,7 +91,7 @@ export async function restartCertificate(
|
||||
.set({
|
||||
status: "pending",
|
||||
errorMessage: null,
|
||||
lastRenewalAttempt: Math.floor(Date.now() / 1000)
|
||||
lastRenewalAttempt: Math.floor(Date.now() / 1000)
|
||||
})
|
||||
.where(eq(certificates.certId, certId));
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types";
|
||||
const paramsSchema = z.strictObject({});
|
||||
|
||||
const querySchema = z.strictObject({
|
||||
subdomain: z.string()
|
||||
});
|
||||
subdomain: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
|
||||
@@ -12,4 +12,4 @@
|
||||
*/
|
||||
|
||||
export * from "./checkDomainNamespaceAvailability";
|
||||
export * from "./listDomainNamespaces";
|
||||
export * from "./listDomainNamespaces";
|
||||
|
||||
@@ -26,19 +26,19 @@ import { OpenAPITags, registry } from "@server/openApi";
|
||||
const paramsSchema = z.strictObject({});
|
||||
|
||||
const querySchema = z.strictObject({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
|
||||
async function query(limit: number, offset: number) {
|
||||
const res = await db
|
||||
|
||||
@@ -30,7 +30,7 @@ import {
|
||||
verifyUserHasAction,
|
||||
verifyUserIsServerAdmin,
|
||||
verifySiteAccess,
|
||||
verifyClientAccess,
|
||||
verifyClientAccess
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import {
|
||||
@@ -436,6 +436,8 @@ authenticated.get(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:clientId/regenerate-client-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyClientAccess,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateClientSecret
|
||||
@@ -443,15 +445,18 @@ authenticated.post(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:siteId/regenerate-site-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifySiteAccess,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateSiteSecret
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/re-key/:orgId/reGenerate-remote-exit-node-secret",
|
||||
"/re-key/:orgId/regenerate-remote-exit-node-secret",
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription,
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.updateRemoteExitNode),
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateExitNodeSecret
|
||||
);
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
/*
|
||||
* This file is part of a proprietary work.
|
||||
*
|
||||
* Copyright (c) 2025 Fossorial, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This file is licensed under the Fossorial Commercial License.
|
||||
* You may not use this file except in compliance with the License.
|
||||
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
|
||||
*
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
@@ -79,86 +79,72 @@ import semver from "semver";
|
||||
|
||||
// Zod schemas for request validation
|
||||
const getResourceByDomainParamsSchema = z.strictObject({
|
||||
domain: z.string().min(1, "Domain is required")
|
||||
});
|
||||
domain: z.string().min(1, "Domain is required")
|
||||
});
|
||||
|
||||
const getUserSessionParamsSchema = z.strictObject({
|
||||
userSessionId: z.string().min(1, "User session ID is required")
|
||||
});
|
||||
userSessionId: z.string().min(1, "User session ID is required")
|
||||
});
|
||||
|
||||
const getUserOrgRoleParamsSchema = z.strictObject({
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
orgId: z.string().min(1, "Organization ID is required")
|
||||
});
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
orgId: z.string().min(1, "Organization ID is required")
|
||||
});
|
||||
|
||||
const getRoleResourceAccessParamsSchema = z.strictObject({
|
||||
roleId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int().positive("Role ID must be a positive integer")
|
||||
),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
roleId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Role ID must be a positive integer")),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const getUserResourceAccessParamsSchema = z.strictObject({
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
userId: z.string().min(1, "User ID is required"),
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const getResourceRulesParamsSchema = z.strictObject({
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const validateResourceSessionTokenParamsSchema = z.strictObject({
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(
|
||||
z.int()
|
||||
.positive("Resource ID must be a positive integer")
|
||||
)
|
||||
});
|
||||
resourceId: z
|
||||
.string()
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive("Resource ID must be a positive integer"))
|
||||
});
|
||||
|
||||
const validateResourceSessionTokenBodySchema = z.strictObject({
|
||||
token: z.string().min(1, "Token is required")
|
||||
});
|
||||
token: z.string().min(1, "Token is required")
|
||||
});
|
||||
|
||||
const validateResourceAccessTokenBodySchema = z.strictObject({
|
||||
accessTokenId: z.string().optional(),
|
||||
resourceId: z.number().optional(),
|
||||
accessToken: z.string()
|
||||
});
|
||||
accessTokenId: z.string().optional(),
|
||||
resourceId: z.number().optional(),
|
||||
accessToken: z.string()
|
||||
});
|
||||
|
||||
// Certificates by domains query validation
|
||||
const getCertificatesByDomainsQuerySchema = z.strictObject({
|
||||
// Accept domains as string or array (domains or domains[])
|
||||
domains: z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional(),
|
||||
// Handle array format from query parameters (domains[])
|
||||
"domains[]": z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional()
|
||||
});
|
||||
// Accept domains as string or array (domains or domains[])
|
||||
domains: z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional(),
|
||||
// Handle array format from query parameters (domains[])
|
||||
"domains[]": z
|
||||
.union([z.array(z.string().min(1)), z.string().min(1)])
|
||||
.optional()
|
||||
});
|
||||
|
||||
// Type exports for request schemas
|
||||
export type GetResourceByDomainParams = z.infer<
|
||||
@@ -566,8 +552,8 @@ hybridRouter.get(
|
||||
);
|
||||
|
||||
const getOrgLoginPageParamsSchema = z.strictObject({
|
||||
orgId: z.string().min(1)
|
||||
});
|
||||
orgId: z.string().min(1)
|
||||
});
|
||||
|
||||
hybridRouter.get(
|
||||
"/org/:orgId/login-page",
|
||||
@@ -1408,8 +1394,16 @@ hybridRouter.post(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
|
||||
parsedParams.data;
|
||||
const {
|
||||
olmId,
|
||||
newtId,
|
||||
ip,
|
||||
port,
|
||||
timestamp,
|
||||
token,
|
||||
publicKey,
|
||||
reachableAt
|
||||
} = parsedParams.data;
|
||||
|
||||
const destinations = await updateAndGenerateEndpointDestinations(
|
||||
olmId,
|
||||
|
||||
@@ -18,7 +18,7 @@ import * as logs from "#private/routers/auditLogs";
|
||||
import {
|
||||
verifyApiKeyHasAction,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyOrgAccess
|
||||
} from "@server/middlewares";
|
||||
import {
|
||||
verifyValidSubscription,
|
||||
@@ -26,7 +26,10 @@ import {
|
||||
} from "#private/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
|
||||
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
|
||||
import {
|
||||
unauthenticated as ua,
|
||||
authenticated as a
|
||||
} from "@server/routers/integration";
|
||||
import { logActionAudit } from "#private/middlewares";
|
||||
|
||||
export const unauthenticated = ua;
|
||||
@@ -37,7 +40,7 @@ authenticated.post(
|
||||
verifyApiKeyIsRoot, // We are the only ones who can use root key so its fine
|
||||
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
|
||||
logActionAudit(ActionsEnum.sendUsageNotification),
|
||||
org.sendUsageNotification,
|
||||
org.sendUsageNotification
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
@@ -45,7 +48,7 @@ authenticated.delete(
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
|
||||
logActionAudit(ActionsEnum.deleteIdp),
|
||||
orgIdp.deleteOrgIdp,
|
||||
orgIdp.deleteOrgIdp
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
|
||||
@@ -21,8 +21,8 @@ import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
|
||||
export async function activateLicense(
|
||||
req: Request,
|
||||
|
||||
@@ -24,8 +24,8 @@ import { licenseKey } from "@server/db";
|
||||
import license from "#private/license/license";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
licenseKey: z.string().min(1).max(255)
|
||||
});
|
||||
|
||||
export async function deleteLicenseKey(
|
||||
req: Request,
|
||||
|
||||
@@ -36,13 +36,13 @@ import { build } from "@server/build";
|
||||
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
subdomain: z.string().nullable().optional(),
|
||||
domainId: z.string()
|
||||
});
|
||||
subdomain: z.string().nullable().optional(),
|
||||
domainId: z.string()
|
||||
});
|
||||
|
||||
export type CreateLoginPageBody = z.infer<typeof bodySchema>;
|
||||
|
||||
@@ -149,12 +149,20 @@ export async function createLoginPage(
|
||||
|
||||
let returned: LoginPage | undefined;
|
||||
await db.transaction(async (trx) => {
|
||||
|
||||
const orgSites = await trx
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
|
||||
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.innerJoin(
|
||||
exitNodes,
|
||||
eq(exitNodes.exitNodeId, sites.exitNodeId)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.orgId, orgId),
|
||||
eq(exitNodes.type, "gerbil"),
|
||||
eq(exitNodes.online, true)
|
||||
)
|
||||
)
|
||||
.limit(10);
|
||||
|
||||
let exitNodesList = orgSites.map((s) => s.exitNodes);
|
||||
@@ -163,7 +171,12 @@ export async function createLoginPage(
|
||||
exitNodesList = await trx
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
|
||||
.where(
|
||||
and(
|
||||
eq(exitNodes.type, "gerbil"),
|
||||
eq(exitNodes.online, true)
|
||||
)
|
||||
)
|
||||
.limit(10);
|
||||
}
|
||||
|
||||
|
||||
@@ -78,15 +78,11 @@ export async function deleteLoginPage(
|
||||
// if (!leftoverLinks.length) {
|
||||
await db
|
||||
.delete(loginPage)
|
||||
.where(
|
||||
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
|
||||
);
|
||||
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId));
|
||||
|
||||
await db
|
||||
.delete(loginPageOrg)
|
||||
.where(
|
||||
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
|
||||
);
|
||||
.where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId));
|
||||
// }
|
||||
|
||||
return response<LoginPage>(res, {
|
||||
|
||||
@@ -23,8 +23,8 @@ import { fromError } from "zod-validation-error";
|
||||
import { GetLoginPageResponse } from "@server/routers/loginPage/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
async function query(orgId: string) {
|
||||
const [res] = await db
|
||||
|
||||
@@ -35,7 +35,8 @@ const paramsSchema = z
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
const bodySchema = z
|
||||
.strictObject({
|
||||
subdomain: subdomainSchema.nullable().optional(),
|
||||
domainId: z.string().optional()
|
||||
})
|
||||
@@ -86,7 +87,7 @@ export async function updateLoginPage(
|
||||
|
||||
const { loginPageId, orgId } = parsedParams.data;
|
||||
|
||||
if (build === "saas"){
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
if (!subscribed) {
|
||||
@@ -182,7 +183,10 @@ export async function updateLoginPage(
|
||||
}
|
||||
|
||||
// update the full domain if it has changed
|
||||
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
|
||||
if (
|
||||
fullDomain &&
|
||||
fullDomain !== existingLoginPage?.fullDomain
|
||||
) {
|
||||
await db
|
||||
.update(loginPage)
|
||||
.set({ fullDomain })
|
||||
|
||||
@@ -23,9 +23,9 @@ import SupportEmail from "@server/emails/templates/SupportEmail";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
body: z.string().min(1),
|
||||
subject: z.string().min(1).max(255)
|
||||
});
|
||||
body: z.string().min(1),
|
||||
subject: z.string().min(1).max(255)
|
||||
});
|
||||
|
||||
export async function sendSupportEmail(
|
||||
req: Request,
|
||||
@@ -66,6 +66,7 @@ export async function sendSupportEmail(
|
||||
{
|
||||
name: req.user?.email || "Support User",
|
||||
to: "support@pangolin.net",
|
||||
replyTo: req.user?.email || undefined,
|
||||
from: config.getNoReplyEmail(),
|
||||
subject: `Support Request: ${subject}`
|
||||
}
|
||||
|
||||
@@ -11,4 +11,4 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./sendUsageNotifications";
|
||||
export * from "./sendUsageNotifications";
|
||||
|
||||
@@ -35,10 +35,12 @@ const sendUsageNotificationBodySchema = z.object({
|
||||
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
|
||||
limitName: z.string(),
|
||||
currentUsage: z.number(),
|
||||
usageLimit: z.number(),
|
||||
usageLimit: z.number()
|
||||
});
|
||||
|
||||
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
|
||||
type SendUsageNotificationRequest = z.infer<
|
||||
typeof sendUsageNotificationBodySchema
|
||||
>;
|
||||
|
||||
export type SendUsageNotificationResponse = {
|
||||
success: boolean;
|
||||
@@ -97,17 +99,13 @@ async function getOrgAdmins(orgId: string) {
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
or(
|
||||
eq(userOrgs.isOwner, true),
|
||||
eq(roles.isAdmin, true)
|
||||
)
|
||||
or(eq(userOrgs.isOwner, true), eq(roles.isAdmin, true))
|
||||
)
|
||||
);
|
||||
|
||||
// Filter to only include users with verified emails
|
||||
const orgAdmins = admins.filter(admin =>
|
||||
admin.email &&
|
||||
admin.email.length > 0
|
||||
const orgAdmins = admins.filter(
|
||||
(admin) => admin.email && admin.email.length > 0
|
||||
);
|
||||
|
||||
return orgAdmins;
|
||||
@@ -119,7 +117,9 @@ export async function sendUsageNotification(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
|
||||
const parsedParams = sendUsageNotificationParamsSchema.safeParse(
|
||||
req.params
|
||||
);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
@@ -140,12 +140,8 @@ export async function sendUsageNotification(
|
||||
}
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
const {
|
||||
notificationType,
|
||||
limitName,
|
||||
currentUsage,
|
||||
usageLimit,
|
||||
} = parsedBody.data;
|
||||
const { notificationType, limitName, currentUsage, usageLimit } =
|
||||
parsedBody.data;
|
||||
|
||||
// Verify organization exists
|
||||
const org = await db
|
||||
@@ -192,7 +188,10 @@ export async function sendUsageNotification(
|
||||
let template;
|
||||
let subject;
|
||||
|
||||
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
|
||||
if (
|
||||
notificationType === "approaching_70" ||
|
||||
notificationType === "approaching_90"
|
||||
) {
|
||||
template = NotifyUsageLimitApproaching({
|
||||
email: admin.email,
|
||||
limitName,
|
||||
@@ -220,10 +219,15 @@ export async function sendUsageNotification(
|
||||
|
||||
emailsSent++;
|
||||
adminEmails.push(admin.email);
|
||||
|
||||
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
|
||||
|
||||
logger.info(
|
||||
`Usage notification sent to admin ${admin.email} for org ${orgId}`
|
||||
);
|
||||
} catch (emailError) {
|
||||
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
|
||||
logger.error(
|
||||
`Failed to send usage notification to ${admin.email}:`,
|
||||
emailError
|
||||
);
|
||||
// Continue with other admins even if one fails
|
||||
}
|
||||
}
|
||||
@@ -239,11 +243,13 @@ export async function sendUsageNotification(
|
||||
message: `Usage notifications sent to ${emailsSent} administrators`,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Error sending usage notifications:", error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to send usage notifications"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,19 +32,19 @@ import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
|
||||
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
name: z.string().nonempty(),
|
||||
clientId: z.string().nonempty(),
|
||||
clientSecret: z.string().nonempty(),
|
||||
authUrl: z.url(),
|
||||
tokenUrl: z.url(),
|
||||
identifierPath: z.string().nonempty(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().nonempty(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
name: z.string().nonempty(),
|
||||
clientId: z.string().nonempty(),
|
||||
clientSecret: z.string().nonempty(),
|
||||
authUrl: z.url(),
|
||||
tokenUrl: z.url(),
|
||||
identifierPath: z.string().nonempty(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().nonempty(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "put",
|
||||
@@ -158,7 +158,10 @@ export async function createOrgOidcIdp(
|
||||
});
|
||||
});
|
||||
|
||||
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
|
||||
const redirectUrl = await generateOidcRedirectUrl(
|
||||
idpId as number,
|
||||
orgId
|
||||
);
|
||||
|
||||
return response<CreateOrgIdpResponse>(res, {
|
||||
data: {
|
||||
|
||||
@@ -66,12 +66,7 @@ export async function deleteOrgIdp(
|
||||
.where(eq(idp.idpId, idpId));
|
||||
|
||||
if (!existingIdp) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"IdP not found"
|
||||
)
|
||||
);
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
|
||||
}
|
||||
|
||||
// Delete the IDP and its related records in a transaction
|
||||
@@ -82,14 +77,10 @@ export async function deleteOrgIdp(
|
||||
.where(eq(idpOidcConfig.idpId, idpId));
|
||||
|
||||
// Delete IDP-org mappings
|
||||
await trx
|
||||
.delete(idpOrg)
|
||||
.where(eq(idpOrg.idpId, idpId));
|
||||
await trx.delete(idpOrg).where(eq(idpOrg.idpId, idpId));
|
||||
|
||||
// Delete the IDP itself
|
||||
await trx
|
||||
.delete(idp)
|
||||
.where(eq(idp.idpId, idpId));
|
||||
await trx.delete(idp).where(eq(idp.idpId, idpId));
|
||||
});
|
||||
|
||||
return response<null>(res, {
|
||||
|
||||
@@ -93,7 +93,10 @@ export async function getOrgIdp(
|
||||
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
|
||||
}
|
||||
|
||||
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
|
||||
const redirectUrl = await generateOidcRedirectUrl(
|
||||
idpRes.idp.idpId,
|
||||
orgId
|
||||
);
|
||||
|
||||
return response<GetOrgIdpResponse>(res, {
|
||||
data: {
|
||||
|
||||
@@ -15,4 +15,4 @@ export * from "./createOrgOidcIdp";
|
||||
export * from "./getOrgIdp";
|
||||
export * from "./listOrgIdps";
|
||||
export * from "./updateOrgOidcIdp";
|
||||
export * from "./deleteOrgIdp";
|
||||
export * from "./deleteOrgIdp";
|
||||
|
||||
@@ -25,23 +25,23 @@ import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { ListOrgIdpsResponse } from "@server/routers/orgIdp/types";
|
||||
|
||||
const querySchema = z.strictObject({
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
limit: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
offset: z
|
||||
.string()
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative())
|
||||
});
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().nonempty()
|
||||
});
|
||||
orgId: z.string().nonempty()
|
||||
});
|
||||
|
||||
async function query(orgId: string, limit: number, offset: number) {
|
||||
const res = await db
|
||||
|
||||
@@ -36,18 +36,18 @@ const paramsSchema = z
|
||||
.strict();
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
name: z.string().optional(),
|
||||
clientId: z.string().optional(),
|
||||
clientSecret: z.string().optional(),
|
||||
authUrl: z.string().optional(),
|
||||
tokenUrl: z.string().optional(),
|
||||
identifierPath: z.string().optional(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().optional(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
name: z.string().optional(),
|
||||
clientId: z.string().optional(),
|
||||
clientSecret: z.string().optional(),
|
||||
authUrl: z.string().optional(),
|
||||
tokenUrl: z.string().optional(),
|
||||
identifierPath: z.string().optional(),
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().optional(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
roleMapping: z.string().optional()
|
||||
});
|
||||
|
||||
export type UpdateOrgIdpResponse = {
|
||||
idpId: number;
|
||||
|
||||
@@ -13,4 +13,4 @@
|
||||
|
||||
export * from "./reGenerateClientSecret";
|
||||
export * from "./reGenerateSiteSecret";
|
||||
export * from "./reGenerateExitNodeSecret";
|
||||
export * from "./reGenerateExitNodeSecret";
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, olms, } from "@server/db";
|
||||
import { db, Olm, olms } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -23,37 +23,19 @@ import { eq, and } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { disconnectClient, sendToClient } from "#private/routers/ws";
|
||||
|
||||
const reGenerateSecretParamsSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const reGenerateSecretBodySchema = z.strictObject({
|
||||
olmId: z.string().min(1).optional(),
|
||||
secret: z.string().min(1).optional(),
|
||||
|
||||
});
|
||||
|
||||
export type ReGenerateSecretBody = z.infer<typeof reGenerateSecretBodySchema>;
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/re-key/{clientId}/regenerate-client-secret",
|
||||
description: "Regenerate a client's OLM credentials by its client ID.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: reGenerateSecretParamsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: reGenerateSecretBodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const reGenerateSecretBodySchema = z.strictObject({
|
||||
// olmId: z.string().min(1).optional(),
|
||||
secret: z.string().min(1),
|
||||
disconnect: z.boolean().optional().default(true)
|
||||
});
|
||||
|
||||
export type ReGenerateSecretBody = z.infer<typeof reGenerateSecretBodySchema>;
|
||||
|
||||
export async function reGenerateClientSecret(
|
||||
req: Request,
|
||||
@@ -71,7 +53,7 @@ export async function reGenerateClientSecret(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, secret } = parsedBody.data;
|
||||
const { secret, disconnect } = parsedBody.data;
|
||||
|
||||
const parsedParams = reGenerateSecretParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -85,11 +67,7 @@ export async function reGenerateClientSecret(
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
let secretHash = undefined;
|
||||
if (secret) {
|
||||
secretHash = await hashPassword(secret);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
// Fetch the client to make sure it exists and the user has access to it
|
||||
const [client] = await db
|
||||
@@ -107,24 +85,59 @@ export async function reGenerateClientSecret(
|
||||
);
|
||||
}
|
||||
|
||||
const [existingOlm] = await db
|
||||
const existingOlms = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.clientId, clientId))
|
||||
.limit(1);
|
||||
.where(eq(olms.clientId, clientId));
|
||||
|
||||
if (existingOlm && olmId && secretHash) {
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
olmId,
|
||||
secretHash
|
||||
})
|
||||
.where(eq(olms.clientId, clientId));
|
||||
if (existingOlms.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`No OLM found for client ID ${clientId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existingOlms.length > 1) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
`Multiple OLM entries found for client ID ${clientId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
secretHash
|
||||
})
|
||||
.where(eq(olms.olmId, existingOlms[0].olmId));
|
||||
|
||||
// Only disconnect if explicitly requested
|
||||
if (disconnect) {
|
||||
const payload = {
|
||||
type: `olm/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
sendToClient(existingOlms[0].olmId, payload).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to olm:",
|
||||
error
|
||||
);
|
||||
});
|
||||
|
||||
disconnectClient(existingOlms[0].olmId).catch((error) => {
|
||||
logger.error("Failed to disconnect olm after re-key:", error);
|
||||
});
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: existingOlm,
|
||||
data: {
|
||||
olmId: existingOlms[0].olmId
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Credentials regenerated successfully",
|
||||
|
||||
@@ -12,7 +12,14 @@
|
||||
*/
|
||||
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
|
||||
import {
|
||||
db,
|
||||
exitNodes,
|
||||
exitNodeOrgs,
|
||||
ExitNode,
|
||||
ExitNodeOrg,
|
||||
RemoteExitNode
|
||||
} from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { remoteExitNodes } from "@server/db";
|
||||
@@ -22,35 +29,17 @@ import { fromError } from "zod-validation-error";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { UpdateRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { disconnectClient, sendToClient } from "#private/routers/ws";
|
||||
|
||||
export const paramsSchema = z.object({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
});
|
||||
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/re-key/{orgId}/regenerate-secret",
|
||||
description: "Regenerate a exit node credentials by its org ID.",
|
||||
tags: [OpenAPITags.Org],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: bodySchema
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
responses: {}
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48),
|
||||
disconnect: z.boolean().optional().default(true)
|
||||
});
|
||||
|
||||
export async function reGenerateExitNodeSecret(
|
||||
@@ -79,13 +68,7 @@ export async function reGenerateExitNodeSecret(
|
||||
);
|
||||
}
|
||||
|
||||
const { remoteExitNodeId, secret } = parsedBody.data;
|
||||
|
||||
if (req.user && !req.userOrgRoleId) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
}
|
||||
const { remoteExitNodeId, secret, disconnect } = parsedBody.data;
|
||||
|
||||
const [existingRemoteExitNode] = await db
|
||||
.select()
|
||||
@@ -94,7 +77,10 @@ export async function reGenerateExitNodeSecret(
|
||||
|
||||
if (!existingRemoteExitNode) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Remote Exit Node does not exist")
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"Remote Exit Node does not exist"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@@ -105,15 +91,39 @@ export async function reGenerateExitNodeSecret(
|
||||
.set({ secretHash })
|
||||
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
|
||||
|
||||
return response<UpdateRemoteExitNodeResponse>(res, {
|
||||
data: {
|
||||
remoteExitNodeId,
|
||||
secret,
|
||||
},
|
||||
// Only disconnect if explicitly requested
|
||||
if (disconnect) {
|
||||
const payload = {
|
||||
type: `remoteExitNode/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
sendToClient(
|
||||
existingRemoteExitNode.remoteExitNodeId,
|
||||
payload
|
||||
).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to remote exit node:",
|
||||
error
|
||||
);
|
||||
});
|
||||
|
||||
disconnectClient(existingRemoteExitNode.remoteExitNodeId).catch(
|
||||
(error) => {
|
||||
logger.error(
|
||||
"Failed to disconnect remote exit node after re-key:",
|
||||
error
|
||||
);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Remote Exit Node secret updated successfully",
|
||||
status: HttpCode.OK,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (e) {
|
||||
logger.error("Failed to update remoteExitNode", e);
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, newts, sites } from "@server/db";
|
||||
import { db, Newt, newts, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -22,38 +22,19 @@ import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
import { addPeer } from "@server/routers/gerbil/peers";
|
||||
|
||||
import { addPeer, deletePeer } from "@server/routers/gerbil/peers";
|
||||
import { getAllowedIps } from "@server/routers/target/helpers";
|
||||
import { disconnectClient, sendToClient } from "#private/routers/ws";
|
||||
|
||||
const updateSiteParamsSchema = z.strictObject({
|
||||
siteId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
siteId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
const updateSiteBodySchema = z.strictObject({
|
||||
type: z.enum(["newt", "wireguard"]),
|
||||
newtId: z.string().min(1).max(255).optional(),
|
||||
newtSecret: z.string().min(1).max(255).optional(),
|
||||
exitNodeId: z.int().positive().optional(),
|
||||
pubKey: z.string().optional(),
|
||||
subnet: z.string().optional(),
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/re-key/{siteId}/regenerate-site-secret",
|
||||
description: "Regenerate a site's Newt or WireGuard credentials by its site ID.",
|
||||
tags: [OpenAPITags.Site],
|
||||
request: {
|
||||
params: updateSiteParamsSchema,
|
||||
body: {
|
||||
content: {
|
||||
"application/json": {
|
||||
schema: updateSiteBodySchema,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
responses: {},
|
||||
type: z.enum(["newt", "wireguard"]),
|
||||
secret: z.string().min(1).max(255).optional(),
|
||||
pubKey: z.string().optional(),
|
||||
disconnect: z.boolean().optional().default(true)
|
||||
});
|
||||
|
||||
export async function reGenerateSiteSecret(
|
||||
@@ -65,74 +46,149 @@ export async function reGenerateSiteSecret(
|
||||
const parsedParams = updateSiteParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, fromError(parsedParams.error).toString())
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsedBody = updateSiteBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, fromError(parsedBody.error).toString())
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { siteId } = parsedParams.data;
|
||||
const { type, exitNodeId, pubKey, subnet, newtId, newtSecret } = parsedBody.data;
|
||||
|
||||
let updatedSite = undefined;
|
||||
const { type, pubKey, secret, disconnect } = parsedBody.data;
|
||||
|
||||
let existingNewt: Newt | null = null;
|
||||
if (type === "newt") {
|
||||
if (!newtSecret) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "newtSecret is required for newt sites")
|
||||
);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(newtSecret);
|
||||
|
||||
updatedSite = await db
|
||||
.update(newts)
|
||||
.set({
|
||||
newtId,
|
||||
secretHash,
|
||||
})
|
||||
.where(eq(newts.siteId, siteId))
|
||||
.returning();
|
||||
|
||||
logger.info(`Regenerated Newt credentials for site ${siteId}`);
|
||||
|
||||
} else if (type === "wireguard") {
|
||||
if (!pubKey) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Public key is required for wireguard sites")
|
||||
);
|
||||
}
|
||||
|
||||
if (!exitNodeId) {
|
||||
if (!secret) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Exit node ID is required for wireguard sites"
|
||||
"newtSecret is required for newt sites"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const secretHash = await hashPassword(secret);
|
||||
|
||||
// get the newt to verify it exists
|
||||
const existingNewts = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.siteId, siteId));
|
||||
|
||||
if (existingNewts.length === 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`No Newt found for site ID ${siteId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (existingNewts.length > 1) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
`Multiple Newts found for site ID ${siteId}`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
existingNewt = existingNewts[0];
|
||||
|
||||
// update the secret on the existing newt
|
||||
await db
|
||||
.update(newts)
|
||||
.set({
|
||||
secretHash
|
||||
})
|
||||
.where(eq(newts.newtId, existingNewts[0].newtId));
|
||||
|
||||
// Only disconnect if explicitly requested
|
||||
if (disconnect) {
|
||||
const payload = {
|
||||
type: `newt/wg/terminate`,
|
||||
data: {}
|
||||
};
|
||||
// Don't await this to prevent blocking the response
|
||||
sendToClient(existingNewts[0].newtId, payload).catch(
|
||||
(error) => {
|
||||
logger.error(
|
||||
"Failed to send termination message to newt:",
|
||||
error
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
disconnectClient(existingNewts[0].newtId).catch((error) => {
|
||||
logger.error(
|
||||
"Failed to disconnect newt after re-key:",
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
logger.info(`Regenerated Newt credentials for site ${siteId}`);
|
||||
} else if (type === "wireguard") {
|
||||
if (!pubKey) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Public key is required for wireguard sites"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
updatedSite = await db.transaction(async (tx) => {
|
||||
await addPeer(exitNodeId, {
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Site with ID ${siteId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db
|
||||
.update(sites)
|
||||
.set({ pubKey })
|
||||
.where(eq(sites.siteId, siteId));
|
||||
|
||||
if (!site) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Site with ID ${siteId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (site.exitNodeId && site.subnet) {
|
||||
await deletePeer(site.exitNodeId, site.pubKey!); // the old pubkey
|
||||
await addPeer(site.exitNodeId, {
|
||||
publicKey: pubKey,
|
||||
allowedIps: subnet ? [subnet] : [],
|
||||
allowedIps: await getAllowedIps(site.siteId)
|
||||
});
|
||||
const result = await tx
|
||||
.update(sites)
|
||||
.set({ pubKey })
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.returning();
|
||||
}
|
||||
|
||||
return result;
|
||||
});
|
||||
|
||||
logger.info(`Regenerated WireGuard credentials for site ${siteId}`);
|
||||
logger.info(
|
||||
`Regenerated WireGuard credentials for site ${siteId}`
|
||||
);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`Transaction failed while regenerating WireGuard secret for site ${siteId}`,
|
||||
@@ -148,17 +204,21 @@ export async function reGenerateSiteSecret(
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
data: updatedSite,
|
||||
data: {
|
||||
newtId: existingNewt ? existingNewt.newtId : undefined
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Credentials regenerated successfully",
|
||||
status: HttpCode.OK,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Unexpected error in reGenerateSiteSecret", error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An unexpected error occurred")
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"An unexpected error occurred"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,9 +36,9 @@ export const paramsSchema = z.object({
|
||||
});
|
||||
|
||||
const bodySchema = z.strictObject({
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
});
|
||||
remoteExitNodeId: z.string().length(15),
|
||||
secret: z.string().length(48)
|
||||
});
|
||||
|
||||
export type CreateRemoteExitNodeBody = z.infer<typeof bodySchema>;
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
|
||||
export async function deleteRemoteExitNode(
|
||||
req: Request,
|
||||
|
||||
@@ -24,9 +24,9 @@ import { fromError } from "zod-validation-error";
|
||||
import { GetRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
|
||||
|
||||
const getRemoteExitNodeSchema = z.strictObject({
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
orgId: z.string().min(1),
|
||||
remoteExitNodeId: z.string().min(1)
|
||||
});
|
||||
|
||||
async function query(remoteExitNodeId: string) {
|
||||
const [remoteExitNode] = await db
|
||||
|
||||
@@ -55,7 +55,8 @@ export async function getRemoteExitNodeToken(
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
|
||||
const { session, remoteExitNode } =
|
||||
await validateRemoteExitNodeSessionToken(token);
|
||||
if (session) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
@@ -103,7 +104,10 @@ export async function getRemoteExitNodeToken(
|
||||
}
|
||||
|
||||
const resToken = generateSessionToken();
|
||||
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
|
||||
await createRemoteExitNodeSession(
|
||||
resToken,
|
||||
existingRemoteExitNode.remoteExitNodeId
|
||||
);
|
||||
|
||||
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);
|
||||
|
||||
|
||||
@@ -33,7 +33,9 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
|
||||
offlineCheckerInterval = setInterval(async () => {
|
||||
try {
|
||||
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
|
||||
const twoMinutesAgo = Math.floor(
|
||||
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
||||
);
|
||||
|
||||
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
||||
const newlyOfflineNodes = await db
|
||||
@@ -48,11 +50,13 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
isNull(exitNodes.lastPing)
|
||||
)
|
||||
)
|
||||
).returning();
|
||||
|
||||
)
|
||||
.returning();
|
||||
|
||||
// Update the sites to offline if they have not pinged either
|
||||
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
|
||||
const exitNodeIds = newlyOfflineNodes.map(
|
||||
(node) => node.exitNodeId
|
||||
);
|
||||
|
||||
const sitesOnNode = await db
|
||||
.select()
|
||||
@@ -63,10 +67,10 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
inArray(sites.exitNodeId, exitNodeIds)
|
||||
)
|
||||
);
|
||||
|
||||
|
||||
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
|
||||
for (const site of sitesOnNode) {
|
||||
if (!site.lastBandwidthUpdate) {
|
||||
if (!site.lastBandwidthUpdate) {
|
||||
continue;
|
||||
}
|
||||
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
|
||||
@@ -77,13 +81,12 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
.where(eq(sites.siteId, site.siteId));
|
||||
}
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
logger.error("Error in offline checker interval", { error });
|
||||
}
|
||||
}, OFFLINE_CHECK_INTERVAL);
|
||||
|
||||
logger.info("Started offline checker interval");
|
||||
logger.debug("Started offline checker interval");
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -100,7 +103,9 @@ export const stopRemoteExitNodeOfflineChecker = (): void => {
|
||||
/**
|
||||
* Handles ping messages from clients and responds with pong
|
||||
*/
|
||||
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
|
||||
export const handleRemoteExitNodePingMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const remoteExitNode = c as RemoteExitNode;
|
||||
|
||||
@@ -120,7 +125,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
|
||||
.update(exitNodes)
|
||||
.set({
|
||||
lastPing: Math.floor(Date.now() / 1000),
|
||||
online: true,
|
||||
online: true
|
||||
})
|
||||
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
|
||||
} catch (error) {
|
||||
@@ -131,10 +136,10 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
|
||||
message: {
|
||||
type: "pong",
|
||||
data: {
|
||||
timestamp: new Date().toISOString(),
|
||||
timestamp: new Date().toISOString()
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -29,7 +29,8 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
|
||||
return;
|
||||
}
|
||||
|
||||
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } = message.data;
|
||||
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } =
|
||||
message.data;
|
||||
|
||||
if (!remoteExitNodeVersion) {
|
||||
logger.warn("Remote exit node version not found");
|
||||
@@ -39,7 +40,10 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
|
||||
// update the version
|
||||
await db
|
||||
.update(remoteExitNodes)
|
||||
.set({ version: remoteExitNodeVersion, secondaryVersion: remoteExitNodeSecondaryVersion })
|
||||
.set({
|
||||
version: remoteExitNodeVersion,
|
||||
secondaryVersion: remoteExitNodeSecondaryVersion
|
||||
})
|
||||
.where(
|
||||
eq(
|
||||
remoteExitNodes.remoteExitNodeId,
|
||||
|
||||
@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
|
||||
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
|
||||
|
||||
const listRemoteExitNodesParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const listRemoteExitNodesSchema = z.object({
|
||||
limit: z
|
||||
|
||||
@@ -22,8 +22,8 @@ import { z } from "zod";
|
||||
import { PickRemoteExitNodeDefaultsResponse } from "@server/routers/remoteExitNode/types";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
export async function pickRemoteExitNodeDefaults(
|
||||
req: Request,
|
||||
|
||||
@@ -38,7 +38,9 @@ export async function quickStartRemoteExitNode(
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(req.body);
|
||||
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(
|
||||
req.body
|
||||
);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -11,4 +11,4 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
export * from "./ws";
|
||||
export * from "./ws";
|
||||
|
||||
@@ -23,4 +23,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"remoteExitNode/ping": handleRemoteExitNodePingMessage
|
||||
};
|
||||
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
|
||||
@@ -37,7 +37,14 @@ import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remot
|
||||
import { rateLimitService } from "#private/lib/rateLimit";
|
||||
import { messageHandlers } from "@server/routers/ws/messageHandlers";
|
||||
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
|
||||
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
|
||||
import {
|
||||
AuthenticatedWebSocket,
|
||||
ClientType,
|
||||
WSMessage,
|
||||
TokenPayload,
|
||||
WebSocketRequest,
|
||||
RedisMessage
|
||||
} from "@server/routers/ws";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
|
||||
// Merge public and private message handlers
|
||||
@@ -55,9 +62,9 @@ const processMessage = async (
|
||||
try {
|
||||
const message: WSMessage = JSON.parse(data.toString());
|
||||
|
||||
logger.debug(
|
||||
`Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||
);
|
||||
// logger.debug(
|
||||
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||
// );
|
||||
|
||||
if (!message.type || typeof message.type !== "string") {
|
||||
throw new Error("Invalid message format: missing or invalid type");
|
||||
@@ -216,7 +223,7 @@ const initializeRedisSubscription = async (): Promise<void> => {
|
||||
// Each node is responsible for restoring its own connection state to Redis
|
||||
// This approach is more efficient than cross-node coordination because:
|
||||
// 1. Each node knows its own connections (source of truth)
|
||||
// 2. No network overhead from broadcasting state between nodes
|
||||
// 2. No network overhead from broadcasting state between nodes
|
||||
// 3. No race conditions from simultaneous updates
|
||||
// 4. Redis becomes eventually consistent as each node restores independently
|
||||
// 5. Simpler logic with better fault tolerance
|
||||
@@ -233,8 +240,10 @@ const recoverConnectionState = async (): Promise<void> => {
|
||||
// Each node simply restores its own local connections to Redis
|
||||
// This is the source of truth - no need for cross-node coordination
|
||||
await restoreLocalConnectionsToRedis();
|
||||
|
||||
logger.info("Redis connection state recovery completed - restored local state");
|
||||
|
||||
logger.info(
|
||||
"Redis connection state recovery completed - restored local state"
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Error during Redis recovery:", error);
|
||||
} finally {
|
||||
@@ -251,8 +260,10 @@ const restoreLocalConnectionsToRedis = async (): Promise<void> => {
|
||||
try {
|
||||
// Restore all current local connections to Redis
|
||||
for (const [clientId, clients] of connectedClients.entries()) {
|
||||
const validClients = clients.filter(client => client.readyState === WebSocket.OPEN);
|
||||
|
||||
const validClients = clients.filter(
|
||||
(client) => client.readyState === WebSocket.OPEN
|
||||
);
|
||||
|
||||
if (validClients.length > 0) {
|
||||
// Add this node to the client's connection list
|
||||
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
|
||||
@@ -303,7 +314,10 @@ const addClient = async (
|
||||
Date.now().toString()
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to add client to Redis tracking (connection still functional locally):", error);
|
||||
logger.error(
|
||||
"Failed to add client to Redis tracking (connection still functional locally):",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -326,9 +340,14 @@ const removeClient = async (
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
|
||||
await redisManager.del(
|
||||
getNodeConnectionsKey(NODE_ID, clientId)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to remove client from Redis tracking (cleanup will occur on recovery):", error);
|
||||
logger.error(
|
||||
"Failed to remove client from Redis tracking (cleanup will occur on recovery):",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,7 +364,10 @@ const removeClient = async (
|
||||
ws.connectionId
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to remove specific connection from Redis tracking:", error);
|
||||
logger.error(
|
||||
"Failed to remove specific connection from Redis tracking:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -372,7 +394,9 @@ const sendToClientLocal = async (
|
||||
}
|
||||
});
|
||||
|
||||
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
|
||||
logger.debug(
|
||||
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
|
||||
);
|
||||
|
||||
return true;
|
||||
};
|
||||
@@ -411,14 +435,22 @@ const sendToClient = async (
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
|
||||
await redisManager.publish(
|
||||
REDIS_CHANNEL,
|
||||
JSON.stringify(redisMessage)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to send message via Redis, message may be lost:", error);
|
||||
logger.error(
|
||||
"Failed to send message via Redis, message may be lost:",
|
||||
error
|
||||
);
|
||||
// Continue execution - local delivery already attempted
|
||||
}
|
||||
} else if (!localSent && !redisManager.isRedisEnabled()) {
|
||||
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
|
||||
logger.debug(`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`);
|
||||
logger.debug(
|
||||
`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`
|
||||
);
|
||||
}
|
||||
|
||||
return localSent;
|
||||
@@ -441,13 +473,21 @@ const broadcastToAllExcept = async (
|
||||
fromNodeId: NODE_ID
|
||||
};
|
||||
|
||||
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
|
||||
await redisManager.publish(
|
||||
REDIS_CHANNEL,
|
||||
JSON.stringify(redisMessage)
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Failed to broadcast message via Redis, remote nodes may not receive it:", error);
|
||||
logger.error(
|
||||
"Failed to broadcast message via Redis, remote nodes may not receive it:",
|
||||
error
|
||||
);
|
||||
// Continue execution - local broadcast already completed
|
||||
}
|
||||
} else {
|
||||
logger.debug("Redis unavailable - broadcast limited to local node only");
|
||||
logger.debug(
|
||||
"Redis unavailable - broadcast limited to local node only"
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -512,8 +552,10 @@ const verifyToken = async (
|
||||
return null;
|
||||
}
|
||||
|
||||
if (olm.userId) { // this is a user device and we need to check the user token
|
||||
const { session: userSession, user } = await validateSessionToken(userToken);
|
||||
if (olm.userId) {
|
||||
// this is a user device and we need to check the user token
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
return null;
|
||||
}
|
||||
@@ -668,7 +710,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
url.searchParams.get("token") ||
|
||||
request.headers["sec-websocket-protocol"] ||
|
||||
"";
|
||||
const userToken = url.searchParams.get('userToken') || '';
|
||||
const userToken = url.searchParams.get("userToken") || "";
|
||||
let clientType = url.searchParams.get(
|
||||
"clientType"
|
||||
) as ClientType;
|
||||
@@ -690,7 +732,11 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
return;
|
||||
}
|
||||
|
||||
const tokenPayload = await verifyToken(token, clientType, userToken);
|
||||
const tokenPayload = await verifyToken(
|
||||
token,
|
||||
clientType,
|
||||
userToken
|
||||
);
|
||||
if (!tokenPayload) {
|
||||
logger.debug(
|
||||
"Unauthorized connection attempt: invalid token..."
|
||||
@@ -724,50 +770,68 @@ const handleWSUpgrade = (server: HttpServer): void => {
|
||||
// Add periodic connection state sync to handle Redis disconnections/reconnections
|
||||
const startPeriodicStateSync = (): void => {
|
||||
// Lightweight sync every 5 minutes - just restore our own state
|
||||
setInterval(async () => {
|
||||
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
|
||||
try {
|
||||
await restoreLocalConnectionsToRedis();
|
||||
logger.debug("Periodic connection state sync completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during periodic connection state sync:", error);
|
||||
setInterval(
|
||||
async () => {
|
||||
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
|
||||
try {
|
||||
await restoreLocalConnectionsToRedis();
|
||||
logger.debug("Periodic connection state sync completed");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Error during periodic connection state sync:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 5 * 60 * 1000); // 5 minutes
|
||||
},
|
||||
5 * 60 * 1000
|
||||
); // 5 minutes
|
||||
|
||||
// Cleanup stale connections every 15 minutes
|
||||
setInterval(async () => {
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await cleanupStaleConnections();
|
||||
logger.debug("Periodic connection cleanup completed");
|
||||
} catch (error) {
|
||||
logger.error("Error during periodic connection cleanup:", error);
|
||||
setInterval(
|
||||
async () => {
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
await cleanupStaleConnections();
|
||||
logger.debug("Periodic connection cleanup completed");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Error during periodic connection cleanup:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 15 * 60 * 1000); // 15 minutes
|
||||
},
|
||||
15 * 60 * 1000
|
||||
); // 15 minutes
|
||||
};
|
||||
|
||||
const cleanupStaleConnections = async (): Promise<void> => {
|
||||
if (!redisManager.isRedisEnabled()) return;
|
||||
|
||||
try {
|
||||
const nodeKeys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
|
||||
|
||||
const nodeKeys =
|
||||
(await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`)) ||
|
||||
[];
|
||||
|
||||
for (const nodeKey of nodeKeys) {
|
||||
const connections = await redisManager.hgetall(nodeKey);
|
||||
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, '');
|
||||
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, "");
|
||||
const localClients = connectedClients.get(clientId) || [];
|
||||
const localConnectionIds = localClients
|
||||
.filter(client => client.readyState === WebSocket.OPEN)
|
||||
.map(client => client.connectionId)
|
||||
.filter((client) => client.readyState === WebSocket.OPEN)
|
||||
.map((client) => client.connectionId)
|
||||
.filter(Boolean);
|
||||
|
||||
// Remove Redis entries for connections that no longer exist locally
|
||||
for (const [connectionId, timestamp] of Object.entries(connections)) {
|
||||
for (const [connectionId, timestamp] of Object.entries(
|
||||
connections
|
||||
)) {
|
||||
if (!localConnectionIds.includes(connectionId)) {
|
||||
await redisManager.hdel(nodeKey, connectionId);
|
||||
logger.debug(`Cleaned up stale connection: ${connectionId} for client: ${clientId}`);
|
||||
logger.debug(
|
||||
`Cleaned up stale connection: ${connectionId} for client: ${clientId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -776,7 +840,9 @@ const cleanupStaleConnections = async (): Promise<void> => {
|
||||
if (Object.keys(remainingConnections).length === 0) {
|
||||
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
|
||||
await redisManager.del(nodeKey);
|
||||
logger.debug(`Cleaned up empty connection tracking for client: ${clientId}`);
|
||||
logger.debug(
|
||||
`Cleaned up empty connection tracking for client: ${clientId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -789,38 +855,38 @@ if (redisManager.isRedisEnabled()) {
|
||||
initializeRedisSubscription().catch((error) => {
|
||||
logger.error("Failed to initialize Redis subscription:", error);
|
||||
});
|
||||
|
||||
|
||||
// Register recovery callback with Redis manager
|
||||
// When Redis reconnects, each node simply restores its own local state
|
||||
redisManager.onReconnection(async () => {
|
||||
logger.info("Redis reconnected, starting WebSocket state recovery...");
|
||||
await recoverConnectionState();
|
||||
});
|
||||
|
||||
|
||||
// Start periodic state synchronization
|
||||
startPeriodicStateSync();
|
||||
|
||||
|
||||
logger.info(
|
||||
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
|
||||
);
|
||||
} else {
|
||||
logger.debug(
|
||||
"WebSocket handler initialized in local mode"
|
||||
);
|
||||
logger.debug("WebSocket handler initialized in local mode");
|
||||
}
|
||||
|
||||
// Disconnect a specific client and force them to reconnect
|
||||
const disconnectClient = async (clientId: string): Promise<boolean> => {
|
||||
const mapKey = getClientMapKey(clientId);
|
||||
const clients = connectedClients.get(mapKey);
|
||||
|
||||
|
||||
if (!clients || clients.length === 0) {
|
||||
logger.debug(`No connections found for client ID: ${clientId}`);
|
||||
return false;
|
||||
}
|
||||
|
||||
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
|
||||
|
||||
logger.info(
|
||||
`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`
|
||||
);
|
||||
|
||||
// Close all connections for this client
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
|
||||
Reference in New Issue
Block a user