Merge dev into fix/log-analytics-adjustments

This commit is contained in:
Fred KISSIE
2025-12-10 03:19:14 +01:00
parent 9db2feff77
commit d490cab48c
555 changed files with 9375 additions and 9287 deletions

View File

@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
import {
queryAccessAuditLogsParams,
queryAccessAuditLogsQuery,
queryAccess
} from "./queryAccessAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -67,10 +71,13 @@ export async function exportAccessAuditLogs(
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +85,4 @@ export async function exportAccessAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -19,7 +19,11 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
import {
queryActionAuditLogsParams,
queryActionAuditLogsQuery,
queryAction
} from "./queryActionAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
registry.registerPath({
@@ -60,17 +64,20 @@ export async function exportActionAuditLogs(
);
}
const data = { ...parsedQuery.data, ...parsedParams.data };
const data = { ...parsedQuery.data, ...parsedParams.data };
const baseQuery = queryAction(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +85,4 @@ export async function exportActionAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -14,4 +14,4 @@
export * from "./queryActionAuditLog";
export * from "./exportActionAuditLog";
export * from "./queryAccessAuditLog";
export * from "./exportAccessAuditLog";
export * from "./exportAccessAuditLog";

View File

@@ -13,4 +13,4 @@
export * from "./transferSession";
export * from "./getSessionTransferToken";
export * from "./quickStart";
export * from "./quickStart";

View File

@@ -395,7 +395,8 @@ export async function quickStart(
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
}).returning();
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
@@ -406,7 +407,12 @@ export async function quickStart(
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {

View File

@@ -26,8 +26,8 @@ import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
const createCheckoutSessionSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function createCheckoutSession(
req: Request,
@@ -72,7 +72,7 @@ export async function createCheckoutSession(
billing_address_collection: "required",
line_items: [
{
price: standardTierPrice, // Use the standard tier
price: standardTierPrice, // Use the standard tier
quantity: 1
},
...getLineItems(getStandardFeaturePriceSet())

View File

@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
const createPortalSessionSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function createPortalSession(
req: Request,

View File

@@ -34,8 +34,8 @@ import {
} from "@server/db";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
registry.registerPath({
method: "get",

View File

@@ -28,8 +28,8 @@ import { FeatureId } from "@server/lib/billing";
import { GetOrgUsageResponse } from "@server/routers/billing/types";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
registry.registerPath({
method: "get",
@@ -78,11 +78,23 @@ export async function getOrgUsage(
// Get usage for org
const usageData = [];
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
const siteUptime = await usageService.getUsage(
orgId,
FeatureId.SITE_UPTIME
);
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
const domains = await usageService.getUsageDaily(
orgId,
FeatureId.DOMAINS
);
const remoteExitNodes = await usageService.getUsageDaily(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
);
if (siteUptime) {
usageData.push(siteUptime);
@@ -100,7 +112,8 @@ export async function getOrgUsage(
usageData.push(remoteExitNodes);
}
const orgLimits = await db.select()
const orgLimits = await db
.select()
.from(limits)
.where(eq(limits.orgId, orgId));

View File

@@ -31,9 +31,7 @@ export async function handleCustomerDeleted(
return;
}
await db
.delete(customers)
.where(eq(customers.customerId, customer.id));
await db.delete(customers).where(eq(customers.customerId, customer.id));
} catch (error) {
logger.error(
`Error handling customer created event for ID ${customer.id}:`,

View File

@@ -12,7 +12,14 @@
*/
import Stripe from "stripe";
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
import {
subscriptions,
db,
subscriptionItems,
customers,
userOrgs,
users
} from "@server/db";
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
@@ -43,7 +50,6 @@ export async function handleSubscriptionDeleted(
.delete(subscriptionItems)
.where(eq(subscriptionItems.subscriptionId, subscription.id));
// Lookup customer to get orgId
const [customer] = await db
.select()
@@ -58,10 +64,7 @@ export async function handleSubscriptionDeleted(
return;
}
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
);
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
const [orgUserRes] = await db
.select()

View File

@@ -15,4 +15,4 @@ export * from "./createCheckoutSession";
export * from "./createPortalSession";
export * from "./getOrgSubscription";
export * from "./getOrgUsage";
export * from "./internalGetOrgTier";
export * from "./internalGetOrgTier";

View File

@@ -22,8 +22,8 @@ import { getOrgTierData } from "#private/lib/billing";
import { GetOrgTierResponse } from "@server/routers/billing/types";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function getOrgTier(
req: Request,

View File

@@ -11,11 +11,18 @@
* This file is not licensed under the AGPLv3.
*/
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
import {
freeLimitSet,
limitsService,
subscribedLimitSet
} from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import logger from "@server/logger";
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
export async function handleSubscriptionLifesycle(
orgId: string,
status: string
) {
switch (status) {
case "active":
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
@@ -42,4 +49,4 @@ export async function handleSubscriptionLifesycle(orgId: string, status: string)
default:
break;
}
}
}

View File

@@ -32,12 +32,13 @@ export async function billingWebhookHandler(
next: NextFunction
): Promise<any> {
let event: Stripe.Event = req.body;
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
const endpointSecret =
privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
if (!endpointSecret) {
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
logger.warn(
"Stripe webhook secret is not configured. Webhook events will not be priocessed."
);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, ""));
}
// Only verify the event if you have an endpoint secret defined.
@@ -49,7 +50,10 @@ export async function billingWebhookHandler(
if (!signature) {
logger.info("No stripe signature found in headers.");
return next(
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
createHttpError(
HttpCode.BAD_REQUEST,
"No stripe signature found in headers"
)
);
}
@@ -62,7 +66,10 @@ export async function billingWebhookHandler(
} catch (err) {
logger.error(`Webhook signature verification failed.`, err);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
createHttpError(
HttpCode.UNAUTHORIZED,
"Webhook signature verification failed"
)
);
}
}

View File

@@ -24,10 +24,10 @@ import { registry } from "@server/openApi";
import { GetCertificateResponse } from "@server/routers/certificates/types";
const getCertificateSchema = z.strictObject({
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
});
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
});
async function query(domainId: string, domain: string) {
const [domainRecord] = await db
@@ -42,8 +42,8 @@ async function query(domainId: string, domain: string) {
let existing: any[] = [];
if (domainRecord.type == "ns") {
const domainLevelDown = domain.split('.').slice(1).join('.');
const domainLevelDown = domain.split(".").slice(1).join(".");
existing = await db
.select({
certId: certificates.certId,
@@ -64,7 +64,7 @@ async function query(domainId: string, domain: string) {
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
or(
eq(certificates.domain, domain),
eq(certificates.domain, domainLevelDown),
eq(certificates.domain, domainLevelDown)
)
)
);
@@ -102,8 +102,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
domainId: z
.string(),
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
})
@@ -133,7 +132,9 @@ export async function getCertificate(
if (!cert) {
logger.warn(`Certificate not found for domain: ${domainId}`);
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
return next(
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
);
}
return response<GetCertificateResponse>(res, {

View File

@@ -12,4 +12,4 @@
*/
export * from "./getCertificate";
export * from "./restartCertificate";
export * from "./restartCertificate";

View File

@@ -25,9 +25,9 @@ import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const restartCertificateParamsSchema = z.strictObject({
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
});
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
});
registry.registerPath({
method: "post",
@@ -36,10 +36,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
certId: z
.string()
.transform(stoi)
.pipe(z.int().positive()),
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
})
},
@@ -94,7 +91,7 @@ export async function restartCertificate(
.set({
status: "pending",
errorMessage: null,
lastRenewalAttempt: Math.floor(Date.now() / 1000)
lastRenewalAttempt: Math.floor(Date.now() / 1000)
})
.where(eq(certificates.certId, certId));

View File

@@ -26,8 +26,8 @@ import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types";
const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({
subdomain: z.string()
});
subdomain: z.string()
});
registry.registerPath({
method: "get",

View File

@@ -12,4 +12,4 @@
*/
export * from "./checkDomainNamespaceAvailability";
export * from "./listDomainNamespaces";
export * from "./listDomainNamespaces";

View File

@@ -26,19 +26,19 @@ import { OpenAPITags, registry } from "@server/openApi";
const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
async function query(limit: number, offset: number) {
const res = await db

View File

@@ -1,13 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/

View File

@@ -79,86 +79,72 @@ import semver from "semver";
// Zod schemas for request validation
const getResourceByDomainParamsSchema = z.strictObject({
domain: z.string().min(1, "Domain is required")
});
domain: z.string().min(1, "Domain is required")
});
const getUserSessionParamsSchema = z.strictObject({
userSessionId: z.string().min(1, "User session ID is required")
});
userSessionId: z.string().min(1, "User session ID is required")
});
const getUserOrgRoleParamsSchema = z.strictObject({
userId: z.string().min(1, "User ID is required"),
orgId: z.string().min(1, "Organization ID is required")
});
userId: z.string().min(1, "User ID is required"),
orgId: z.string().min(1, "Organization ID is required")
});
const getRoleResourceAccessParamsSchema = z.strictObject({
roleId: z
.string()
.transform(Number)
.pipe(
z.int().positive("Role ID must be a positive integer")
),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
roleId: z
.string()
.transform(Number)
.pipe(z.int().positive("Role ID must be a positive integer")),
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getUserResourceAccessParamsSchema = z.strictObject({
userId: z.string().min(1, "User ID is required"),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
userId: z.string().min(1, "User ID is required"),
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getResourceRulesParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenBodySchema = z.strictObject({
token: z.string().min(1, "Token is required")
});
token: z.string().min(1, "Token is required")
});
const validateResourceAccessTokenBodySchema = z.strictObject({
accessTokenId: z.string().optional(),
resourceId: z.number().optional(),
accessToken: z.string()
});
accessTokenId: z.string().optional(),
resourceId: z.number().optional(),
accessToken: z.string()
});
// Certificates by domains query validation
const getCertificatesByDomainsQuerySchema = z.strictObject({
// Accept domains as string or array (domains or domains[])
domains: z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional(),
// Handle array format from query parameters (domains[])
"domains[]": z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional()
});
// Accept domains as string or array (domains or domains[])
domains: z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional(),
// Handle array format from query parameters (domains[])
"domains[]": z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional()
});
// Type exports for request schemas
export type GetResourceByDomainParams = z.infer<
@@ -566,8 +552,8 @@ hybridRouter.get(
);
const getOrgLoginPageParamsSchema = z.strictObject({
orgId: z.string().min(1)
});
orgId: z.string().min(1)
});
hybridRouter.get(
"/org/:orgId/login-page",
@@ -1408,8 +1394,16 @@ hybridRouter.post(
);
}
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
parsedParams.data;
const {
olmId,
newtId,
ip,
port,
timestamp,
token,
publicKey,
reachableAt
} = parsedParams.data;
const destinations = await updateAndGenerateEndpointDestinations(
olmId,

View File

@@ -18,7 +18,7 @@ import * as logs from "#private/routers/auditLogs";
import {
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess,
verifyApiKeyOrgAccess
} from "@server/middlewares";
import {
verifyValidSubscription,
@@ -26,7 +26,10 @@ import {
} from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
import {
unauthenticated as ua,
authenticated as a
} from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares";
export const unauthenticated = ua;
@@ -37,7 +40,7 @@ authenticated.post(
verifyApiKeyIsRoot, // We are the only ones who can use root key so its fine
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
logActionAudit(ActionsEnum.sendUsageNotification),
org.sendUsageNotification,
org.sendUsageNotification
);
authenticated.delete(
@@ -45,7 +48,7 @@ authenticated.delete(
verifyApiKeyIsRoot,
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
logActionAudit(ActionsEnum.deleteIdp),
orgIdp.deleteOrgIdp,
orgIdp.deleteOrgIdp
);
authenticated.get(

View File

@@ -21,8 +21,8 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
const bodySchema = z.strictObject({
licenseKey: z.string().min(1).max(255)
});
licenseKey: z.string().min(1).max(255)
});
export async function activateLicense(
req: Request,

View File

@@ -24,8 +24,8 @@ import { licenseKey } from "@server/db";
import license from "#private/license/license";
const paramsSchema = z.strictObject({
licenseKey: z.string().min(1).max(255)
});
licenseKey: z.string().min(1).max(255)
});
export async function deleteLicenseKey(
req: Request,

View File

@@ -36,13 +36,13 @@ import { build } from "@server/build";
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const bodySchema = z.strictObject({
subdomain: z.string().nullable().optional(),
domainId: z.string()
});
subdomain: z.string().nullable().optional(),
domainId: z.string()
});
export type CreateLoginPageBody = z.infer<typeof bodySchema>;
@@ -149,12 +149,20 @@ export async function createLoginPage(
let returned: LoginPage | undefined;
await db.transaction(async (trx) => {
const orgSites = await trx
.select()
.from(sites)
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.innerJoin(
exitNodes,
eq(exitNodes.exitNodeId, sites.exitNodeId)
)
.where(
and(
eq(sites.orgId, orgId),
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
let exitNodesList = orgSites.map((s) => s.exitNodes);
@@ -163,7 +171,12 @@ export async function createLoginPage(
exitNodesList = await trx
.select()
.from(exitNodes)
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.where(
and(
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
}

View File

@@ -78,15 +78,11 @@ export async function deleteLoginPage(
// if (!leftoverLinks.length) {
await db
.delete(loginPage)
.where(
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId));
await db
.delete(loginPageOrg)
.where(
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId));
// }
return response<LoginPage>(res, {

View File

@@ -23,8 +23,8 @@ import { fromError } from "zod-validation-error";
import { GetLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
async function query(orgId: string) {
const [res] = await db

View File

@@ -35,7 +35,8 @@ const paramsSchema = z
})
.strict();
const bodySchema = z.strictObject({
const bodySchema = z
.strictObject({
subdomain: subdomainSchema.nullable().optional(),
domainId: z.string().optional()
})
@@ -86,7 +87,7 @@ export async function updateLoginPage(
const { loginPageId, orgId } = parsedParams.data;
if (build === "saas"){
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
@@ -182,7 +183,10 @@ export async function updateLoginPage(
}
// update the full domain if it has changed
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
if (
fullDomain &&
fullDomain !== existingLoginPage?.fullDomain
) {
await db
.update(loginPage)
.set({ fullDomain })

View File

@@ -23,9 +23,9 @@ import SupportEmail from "@server/emails/templates/SupportEmail";
import config from "@server/lib/config";
const bodySchema = z.strictObject({
body: z.string().min(1),
subject: z.string().min(1).max(255)
});
body: z.string().min(1),
subject: z.string().min(1).max(255)
});
export async function sendSupportEmail(
req: Request,

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./sendUsageNotifications";
export * from "./sendUsageNotifications";

View File

@@ -35,10 +35,12 @@ const sendUsageNotificationBodySchema = z.object({
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
limitName: z.string(),
currentUsage: z.number(),
usageLimit: z.number(),
usageLimit: z.number()
});
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
type SendUsageNotificationRequest = z.infer<
typeof sendUsageNotificationBodySchema
>;
export type SendUsageNotificationResponse = {
success: boolean;
@@ -97,17 +99,13 @@ async function getOrgAdmins(orgId: string) {
.where(
and(
eq(userOrgs.orgId, orgId),
or(
eq(userOrgs.isOwner, true),
eq(roles.isAdmin, true)
)
or(eq(userOrgs.isOwner, true), eq(roles.isAdmin, true))
)
);
// Filter to only include users with verified emails
const orgAdmins = admins.filter(admin =>
admin.email &&
admin.email.length > 0
const orgAdmins = admins.filter(
(admin) => admin.email && admin.email.length > 0
);
return orgAdmins;
@@ -119,7 +117,9 @@ export async function sendUsageNotification(
next: NextFunction
): Promise<any> {
try {
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
const parsedParams = sendUsageNotificationParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
@@ -140,12 +140,8 @@ export async function sendUsageNotification(
}
const { orgId } = parsedParams.data;
const {
notificationType,
limitName,
currentUsage,
usageLimit,
} = parsedBody.data;
const { notificationType, limitName, currentUsage, usageLimit } =
parsedBody.data;
// Verify organization exists
const org = await db
@@ -192,7 +188,10 @@ export async function sendUsageNotification(
let template;
let subject;
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
if (
notificationType === "approaching_70" ||
notificationType === "approaching_90"
) {
template = NotifyUsageLimitApproaching({
email: admin.email,
limitName,
@@ -220,10 +219,15 @@ export async function sendUsageNotification(
emailsSent++;
adminEmails.push(admin.email);
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
logger.info(
`Usage notification sent to admin ${admin.email} for org ${orgId}`
);
} catch (emailError) {
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
logger.error(
`Failed to send usage notification to ${admin.email}:`,
emailError
);
// Continue with other admins even if one fails
}
}
@@ -239,11 +243,13 @@ export async function sendUsageNotification(
message: `Usage notifications sent to ${emailsSent} administrators`,
status: HttpCode.OK
});
} catch (error) {
logger.error("Error sending usage notifications:", error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to send usage notifications"
)
);
}
}

View File

@@ -32,19 +32,19 @@ import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
const bodySchema = z.strictObject({
name: z.string().nonempty(),
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.url(),
tokenUrl: z.url(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
});
name: z.string().nonempty(),
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.url(),
tokenUrl: z.url(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
});
// registry.registerPath({
// method: "put",
@@ -158,7 +158,10 @@ export async function createOrgOidcIdp(
});
});
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpId as number,
orgId
);
return response<CreateOrgIdpResponse>(res, {
data: {

View File

@@ -66,12 +66,7 @@ export async function deleteOrgIdp(
.where(eq(idp.idpId, idpId));
if (!existingIdp) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"IdP not found"
)
);
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
}
// Delete the IDP and its related records in a transaction
@@ -82,14 +77,10 @@ export async function deleteOrgIdp(
.where(eq(idpOidcConfig.idpId, idpId));
// Delete IDP-org mappings
await trx
.delete(idpOrg)
.where(eq(idpOrg.idpId, idpId));
await trx.delete(idpOrg).where(eq(idpOrg.idpId, idpId));
// Delete the IDP itself
await trx
.delete(idp)
.where(eq(idp.idpId, idpId));
await trx.delete(idp).where(eq(idp.idpId, idpId));
});
return response<null>(res, {

View File

@@ -93,7 +93,10 @@ export async function getOrgIdp(
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
}
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpRes.idp.idpId,
orgId
);
return response<GetOrgIdpResponse>(res, {
data: {

View File

@@ -15,4 +15,4 @@ export * from "./createOrgOidcIdp";
export * from "./getOrgIdp";
export * from "./listOrgIdps";
export * from "./updateOrgOidcIdp";
export * from "./deleteOrgIdp";
export * from "./deleteOrgIdp";

View File

@@ -25,23 +25,23 @@ import { OpenAPITags, registry } from "@server/openApi";
import { ListOrgIdpsResponse } from "@server/routers/orgIdp/types";
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
});
orgId: z.string().nonempty()
});
async function query(orgId: string, limit: number, offset: number) {
const res = await db

View File

@@ -36,18 +36,18 @@ const paramsSchema = z
.strict();
const bodySchema = z.strictObject({
name: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional(),
authUrl: z.string().optional(),
tokenUrl: z.string().optional(),
identifierPath: z.string().optional(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
});
name: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional(),
authUrl: z.string().optional(),
tokenUrl: z.string().optional(),
identifierPath: z.string().optional(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
});
export type UpdateOrgIdpResponse = {
idpId: number;

View File

@@ -13,4 +13,4 @@
export * from "./reGenerateClientSecret";
export * from "./reGenerateSiteSecret";
export * from "./reGenerateExitNodeSecret";
export * from "./reGenerateExitNodeSecret";

View File

@@ -123,7 +123,10 @@ export async function reGenerateClientSecret(
};
// Don't await this to prevent blocking the response
sendToClient(existingOlms[0].olmId, payload).catch((error) => {
logger.error("Failed to send termination message to olm:", error);
logger.error(
"Failed to send termination message to olm:",
error
);
});
disconnectClient(existingOlms[0].olmId).catch((error) => {
@@ -133,7 +136,7 @@ export async function reGenerateClientSecret(
return response(res, {
data: {
olmId: existingOlms[0].olmId,
olmId: existingOlms[0].olmId
},
success: true,
error: false,

View File

@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg, RemoteExitNode } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
RemoteExitNode
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -91,14 +98,15 @@ export async function reGenerateExitNodeSecret(
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingRemoteExitNode.remoteExitNodeId, payload).catch(
(error) => {
logger.error(
"Failed to send termination message to remote exit node:",
error
);
}
);
sendToClient(
existingRemoteExitNode.remoteExitNodeId,
payload
).catch((error) => {
logger.error(
"Failed to send termination message to remote exit node:",
error
);
});
disconnectClient(existingRemoteExitNode.remoteExitNodeId).catch(
(error) => {

View File

@@ -80,7 +80,7 @@ export async function reGenerateSiteSecret(
const secretHash = await hashPassword(secret);
// get the newt to verify it exists
const existingNewts = await db
const existingNewts = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId));
@@ -120,15 +120,20 @@ export async function reGenerateSiteSecret(
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingNewts[0].newtId, payload).catch((error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
});
sendToClient(existingNewts[0].newtId, payload).catch(
(error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
}
);
disconnectClient(existingNewts[0].newtId).catch((error) => {
logger.error("Failed to disconnect newt after re-key:", error);
logger.error(
"Failed to disconnect newt after re-key:",
error
);
});
}

View File

@@ -36,9 +36,9 @@ export const paramsSchema = z.object({
});
const bodySchema = z.strictObject({
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
export type CreateRemoteExitNodeBody = z.infer<typeof bodySchema>;

View File

@@ -25,9 +25,9 @@ import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
const paramsSchema = z.strictObject({
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
export async function deleteRemoteExitNode(
req: Request,

View File

@@ -24,9 +24,9 @@ import { fromError } from "zod-validation-error";
import { GetRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
const getRemoteExitNodeSchema = z.strictObject({
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
async function query(remoteExitNodeId: string) {
const [remoteExitNode] = await db

View File

@@ -55,7 +55,8 @@ export async function getRemoteExitNodeToken(
try {
if (token) {
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
@@ -103,7 +104,10 @@ export async function getRemoteExitNodeToken(
}
const resToken = generateSessionToken();
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
await createRemoteExitNodeSession(
resToken,
existingRemoteExitNode.remoteExitNodeId
);
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);

View File

@@ -33,7 +33,9 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const newlyOfflineNodes = await db
@@ -48,11 +50,13 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
isNull(exitNodes.lastPing)
)
)
).returning();
)
.returning();
// Update the sites to offline if they have not pinged either
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
const exitNodeIds = newlyOfflineNodes.map(
(node) => node.exitNodeId
);
const sitesOnNode = await db
.select()
@@ -77,7 +81,6 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
.where(eq(sites.siteId, site.siteId));
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
@@ -100,7 +103,9 @@ export const stopRemoteExitNodeOfflineChecker = (): void => {
/**
* Handles ping messages from clients and responds with pong
*/
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
export const handleRemoteExitNodePingMessage: MessageHandler = async (
context
) => {
const { message, client: c, sendToClient } = context;
const remoteExitNode = c as RemoteExitNode;
@@ -120,7 +125,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
.update(exitNodes)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true,
online: true
})
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
} catch (error) {
@@ -131,7 +136,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
timestamp: new Date().toISOString()
}
},
broadcast: false,

View File

@@ -29,7 +29,8 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
return;
}
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } = message.data;
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } =
message.data;
if (!remoteExitNodeVersion) {
logger.warn("Remote exit node version not found");
@@ -39,7 +40,10 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
// update the version
await db
.update(remoteExitNodes)
.set({ version: remoteExitNodeVersion, secondaryVersion: remoteExitNodeSecondaryVersion })
.set({
version: remoteExitNodeVersion,
secondaryVersion: remoteExitNodeSecondaryVersion
})
.where(
eq(
remoteExitNodes.remoteExitNodeId,

View File

@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
const listRemoteExitNodesParamsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const listRemoteExitNodesSchema = z.object({
limit: z

View File

@@ -22,8 +22,8 @@ import { z } from "zod";
import { PickRemoteExitNodeDefaultsResponse } from "@server/routers/remoteExitNode/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function pickRemoteExitNodeDefaults(
req: Request,

View File

@@ -38,7 +38,9 @@ export async function quickStartRemoteExitNode(
next: NextFunction
): Promise<any> {
try {
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(req.body);
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(
req.body
);
if (!parsedBody.success) {
return next(
createHttpError(

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./ws";
export * from "./ws";

View File

@@ -23,4 +23,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
"remoteExitNode/ping": handleRemoteExitNodePingMessage
};
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes

View File

@@ -37,7 +37,14 @@ import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remot
import { rateLimitService } from "#private/lib/rateLimit";
import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
import {
AuthenticatedWebSocket,
ClientType,
WSMessage,
TokenPayload,
WebSocketRequest,
RedisMessage
} from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app";
// Merge public and private message handlers
@@ -55,9 +62,9 @@ const processMessage = async (
try {
const message: WSMessage = JSON.parse(data.toString());
logger.debug(
`Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
);
// logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
// );
if (!message.type || typeof message.type !== "string") {
throw new Error("Invalid message format: missing or invalid type");
@@ -216,7 +223,7 @@ const initializeRedisSubscription = async (): Promise<void> => {
// Each node is responsible for restoring its own connection state to Redis
// This approach is more efficient than cross-node coordination because:
// 1. Each node knows its own connections (source of truth)
// 2. No network overhead from broadcasting state between nodes
// 2. No network overhead from broadcasting state between nodes
// 3. No race conditions from simultaneous updates
// 4. Redis becomes eventually consistent as each node restores independently
// 5. Simpler logic with better fault tolerance
@@ -233,8 +240,10 @@ const recoverConnectionState = async (): Promise<void> => {
// Each node simply restores its own local connections to Redis
// This is the source of truth - no need for cross-node coordination
await restoreLocalConnectionsToRedis();
logger.info("Redis connection state recovery completed - restored local state");
logger.info(
"Redis connection state recovery completed - restored local state"
);
} catch (error) {
logger.error("Error during Redis recovery:", error);
} finally {
@@ -251,8 +260,10 @@ const restoreLocalConnectionsToRedis = async (): Promise<void> => {
try {
// Restore all current local connections to Redis
for (const [clientId, clients] of connectedClients.entries()) {
const validClients = clients.filter(client => client.readyState === WebSocket.OPEN);
const validClients = clients.filter(
(client) => client.readyState === WebSocket.OPEN
);
if (validClients.length > 0) {
// Add this node to the client's connection list
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
@@ -303,7 +314,10 @@ const addClient = async (
Date.now().toString()
);
} catch (error) {
logger.error("Failed to add client to Redis tracking (connection still functional locally):", error);
logger.error(
"Failed to add client to Redis tracking (connection still functional locally):",
error
);
}
}
@@ -326,9 +340,14 @@ const removeClient = async (
if (redisManager.isRedisEnabled()) {
try {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
await redisManager.del(
getNodeConnectionsKey(NODE_ID, clientId)
);
} catch (error) {
logger.error("Failed to remove client from Redis tracking (cleanup will occur on recovery):", error);
logger.error(
"Failed to remove client from Redis tracking (cleanup will occur on recovery):",
error
);
}
}
@@ -345,7 +364,10 @@ const removeClient = async (
ws.connectionId
);
} catch (error) {
logger.error("Failed to remove specific connection from Redis tracking:", error);
logger.error(
"Failed to remove specific connection from Redis tracking:",
error
);
}
}
@@ -372,7 +394,9 @@ const sendToClientLocal = async (
}
});
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
);
return true;
};
@@ -411,14 +435,22 @@ const sendToClient = async (
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error("Failed to send message via Redis, message may be lost:", error);
logger.error(
"Failed to send message via Redis, message may be lost:",
error
);
// Continue execution - local delivery already attempted
}
} else if (!localSent && !redisManager.isRedisEnabled()) {
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
logger.debug(`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`);
logger.debug(
`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`
);
}
return localSent;
@@ -441,13 +473,21 @@ const broadcastToAllExcept = async (
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error("Failed to broadcast message via Redis, remote nodes may not receive it:", error);
logger.error(
"Failed to broadcast message via Redis, remote nodes may not receive it:",
error
);
// Continue execution - local broadcast already completed
}
} else {
logger.debug("Redis unavailable - broadcast limited to local node only");
logger.debug(
"Redis unavailable - broadcast limited to local node only"
);
}
};
@@ -512,8 +552,10 @@ const verifyToken = async (
return null;
}
if (olm.userId) { // this is a user device and we need to check the user token
const { session: userSession, user } = await validateSessionToken(userToken);
if (olm.userId) {
// this is a user device and we need to check the user token
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
return null;
}
@@ -668,7 +710,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
url.searchParams.get("token") ||
request.headers["sec-websocket-protocol"] ||
"";
const userToken = url.searchParams.get('userToken') || '';
const userToken = url.searchParams.get("userToken") || "";
let clientType = url.searchParams.get(
"clientType"
) as ClientType;
@@ -690,7 +732,11 @@ const handleWSUpgrade = (server: HttpServer): void => {
return;
}
const tokenPayload = await verifyToken(token, clientType, userToken);
const tokenPayload = await verifyToken(
token,
clientType,
userToken
);
if (!tokenPayload) {
logger.debug(
"Unauthorized connection attempt: invalid token..."
@@ -724,50 +770,68 @@ const handleWSUpgrade = (server: HttpServer): void => {
// Add periodic connection state sync to handle Redis disconnections/reconnections
const startPeriodicStateSync = (): void => {
// Lightweight sync every 5 minutes - just restore our own state
setInterval(async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error("Error during periodic connection state sync:", error);
setInterval(
async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error(
"Error during periodic connection state sync:",
error
);
}
}
}
}, 5 * 60 * 1000); // 5 minutes
},
5 * 60 * 1000
); // 5 minutes
// Cleanup stale connections every 15 minutes
setInterval(async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error("Error during periodic connection cleanup:", error);
setInterval(
async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error(
"Error during periodic connection cleanup:",
error
);
}
}
}
}, 15 * 60 * 1000); // 15 minutes
},
15 * 60 * 1000
); // 15 minutes
};
const cleanupStaleConnections = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
try {
const nodeKeys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
const nodeKeys =
(await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`)) ||
[];
for (const nodeKey of nodeKeys) {
const connections = await redisManager.hgetall(nodeKey);
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, '');
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, "");
const localClients = connectedClients.get(clientId) || [];
const localConnectionIds = localClients
.filter(client => client.readyState === WebSocket.OPEN)
.map(client => client.connectionId)
.filter((client) => client.readyState === WebSocket.OPEN)
.map((client) => client.connectionId)
.filter(Boolean);
// Remove Redis entries for connections that no longer exist locally
for (const [connectionId, timestamp] of Object.entries(connections)) {
for (const [connectionId, timestamp] of Object.entries(
connections
)) {
if (!localConnectionIds.includes(connectionId)) {
await redisManager.hdel(nodeKey, connectionId);
logger.debug(`Cleaned up stale connection: ${connectionId} for client: ${clientId}`);
logger.debug(
`Cleaned up stale connection: ${connectionId} for client: ${clientId}`
);
}
}
@@ -776,7 +840,9 @@ const cleanupStaleConnections = async (): Promise<void> => {
if (Object.keys(remainingConnections).length === 0) {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(nodeKey);
logger.debug(`Cleaned up empty connection tracking for client: ${clientId}`);
logger.debug(
`Cleaned up empty connection tracking for client: ${clientId}`
);
}
}
} catch (error) {
@@ -789,38 +855,38 @@ if (redisManager.isRedisEnabled()) {
initializeRedisSubscription().catch((error) => {
logger.error("Failed to initialize Redis subscription:", error);
});
// Register recovery callback with Redis manager
// When Redis reconnects, each node simply restores its own local state
redisManager.onReconnection(async () => {
logger.info("Redis reconnected, starting WebSocket state recovery...");
await recoverConnectionState();
});
// Start periodic state synchronization
startPeriodicStateSync();
logger.info(
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
);
} else {
logger.debug(
"WebSocket handler initialized in local mode"
);
logger.debug("WebSocket handler initialized in local mode");
}
// Disconnect a specific client and force them to reconnect
const disconnectClient = async (clientId: string): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
logger.debug(`No connections found for client ID: ${clientId}`);
return false;
}
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
logger.info(
`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`
);
// Close all connections for this client
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {