Merge branch 'dev' into multi-role

This commit is contained in:
miloschwartz
2026-03-24 22:01:13 -07:00
266 changed files with 7813 additions and 5880 deletions

View File

@@ -32,7 +32,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/access/export",
description: "Export the access audit log for an organization as CSV",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery,
params: queryAccessAuditLogsParams

View File

@@ -32,7 +32,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/action/export",
description: "Export the action audit log for an organization as CSV",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryActionAuditLogsQuery,
params: queryActionAuditLogsParams

View File

@@ -11,11 +11,11 @@
* This file is not licensed under the AGPLv3.
*/
import { accessAuditLog, db, resources } from "@server/db";
import { accessAuditLog, logsDb, resources, db, primaryDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
import { eq, gt, lt, and, count, desc } from "drizzle-orm";
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
import { OpenAPITags } from "@server/openApi";
import { z } from "zod";
import createHttpError from "http-errors";
@@ -115,15 +115,13 @@ function getWhere(data: Q) {
}
export function queryAccess(data: Q) {
return db
return logsDb
.select({
orgId: accessAuditLog.orgId,
action: accessAuditLog.action,
actorType: accessAuditLog.actorType,
actorId: accessAuditLog.actorId,
resourceId: accessAuditLog.resourceId,
resourceName: resources.name,
resourceNiceId: resources.niceId,
ip: accessAuditLog.ip,
location: accessAuditLog.location,
userAgent: accessAuditLog.userAgent,
@@ -133,16 +131,46 @@ export function queryAccess(data: Q) {
actor: accessAuditLog.actor
})
.from(accessAuditLog)
.leftJoin(
resources,
eq(accessAuditLog.resourceId, resources.resourceId)
)
.where(getWhere(data))
.orderBy(desc(accessAuditLog.timestamp), desc(accessAuditLog.id));
}
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryAccess>>) {
// If logs database is the same as main database, we can do a join
// Otherwise, we need to fetch resource details separately
const resourceIds = logs
.map(log => log.resourceId)
.filter((id): id is number => id !== null && id !== undefined);
if (resourceIds.length === 0) {
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
}
// Fetch resource details from main database
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name,
niceId: resources.niceId
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
// Create a map for quick lookup
const resourceMap = new Map(
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
);
// Enrich logs with resource details
return logs.map(log => ({
...log,
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
}));
}
export function countAccessQuery(data: Q) {
const countQuery = db
const countQuery = logsDb
.select({ count: count() })
.from(accessAuditLog)
.where(getWhere(data));
@@ -161,7 +189,7 @@ async function queryUniqueFilterAttributes(
);
// Get unique actors
const uniqueActors = await db
const uniqueActors = await logsDb
.selectDistinct({
actor: accessAuditLog.actor
})
@@ -169,7 +197,7 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
// Get unique locations
const uniqueLocations = await db
const uniqueLocations = await logsDb
.selectDistinct({
locations: accessAuditLog.location
})
@@ -177,25 +205,40 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
// Get unique resources with names
const uniqueResources = await db
const uniqueResources = await logsDb
.selectDistinct({
id: accessAuditLog.resourceId,
name: resources.name
id: accessAuditLog.resourceId
})
.from(accessAuditLog)
.leftJoin(
resources,
eq(accessAuditLog.resourceId, resources.resourceId)
)
.where(baseConditions);
// Fetch resource names from main database for the unique resource IDs
const resourceIds = uniqueResources
.map(row => row.id)
.filter((id): id is number => id !== null);
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
if (resourceIds.length > 0) {
const resourceDetails = await primaryDb
.select({
resourceId: resources.resourceId,
name: resources.name
})
.from(resources)
.where(inArray(resources.resourceId, resourceIds));
resourcesWithNames = resourceDetails.map(r => ({
id: r.resourceId,
name: r.name
}));
}
return {
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
resources: resourcesWithNames,
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
@@ -206,7 +249,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/access",
description: "Query the access audit log for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery,
params: queryAccessAuditLogsParams
@@ -243,7 +286,10 @@ export async function queryAccessAuditLogs(
const baseQuery = queryAccess(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const logsRaw = await baseQuery.limit(data.limit).offset(data.offset);
// Enrich with resource details (handles cross-database scenario)
const log = await enrichWithResourceDetails(logsRaw);
const totalCountResult = await countAccessQuery(data);
const totalCount = totalCountResult[0].count;

View File

@@ -11,7 +11,7 @@
* This file is not licensed under the AGPLv3.
*/
import { actionAuditLog, db } from "@server/db";
import { actionAuditLog, logsDb } from "@server/db";
import { registry } from "@server/openApi";
import { NextFunction } from "express";
import { Request, Response } from "express";
@@ -97,7 +97,7 @@ function getWhere(data: Q) {
}
export function queryAction(data: Q) {
return db
return logsDb
.select({
orgId: actionAuditLog.orgId,
action: actionAuditLog.action,
@@ -113,7 +113,7 @@ export function queryAction(data: Q) {
}
export function countActionQuery(data: Q) {
const countQuery = db
const countQuery = logsDb
.select({ count: count() })
.from(actionAuditLog)
.where(getWhere(data));
@@ -132,14 +132,14 @@ async function queryUniqueFilterAttributes(
);
// Get unique actors
const uniqueActors = await db
const uniqueActors = await logsDb
.selectDistinct({
actor: actionAuditLog.actor
})
.from(actionAuditLog)
.where(baseConditions);
const uniqueActions = await db
const uniqueActions = await logsDb
.selectDistinct({
action: actionAuditLog.action
})
@@ -160,7 +160,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/action",
description: "Query the action audit log for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryActionAuditLogsQuery,
params: queryActionAuditLogsParams

View File

@@ -31,16 +31,16 @@ const getOrgSchema = z.strictObject({
orgId: z.string()
});
registry.registerPath({
method: "get",
path: "/org/{orgId}/billing/usage",
description: "Get an organization's billing usage",
tags: [OpenAPITags.Org],
request: {
params: getOrgSchema
},
responses: {}
});
// registry.registerPath({
// method: "get",
// path: "/org/{orgId}/billing/usage",
// description: "Get an organization's billing usage",
// tags: [OpenAPITags.Org],
// request: {
// params: getOrgSchema
// },
// responses: {}
// });
export async function getOrgUsage(
req: Request,

View File

@@ -24,7 +24,6 @@ import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import stripe from "#private/lib/stripe";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import privateConfig from "#private/lib/config";
import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
@@ -172,7 +171,7 @@ export async function handleSubscriptionCreated(
const email = orgUserRes.user.email;
if (email) {
moveEmailToAudience(email, AudienceIds.Subscribed);
// TODO: update user in Sendy
}
}
} else if (type === "license") {

View File

@@ -23,7 +23,6 @@ import {
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
import { getSubType } from "./getSubType";
import stripe from "#private/lib/stripe";
import privateConfig from "#private/lib/config";
@@ -109,7 +108,7 @@ export async function handleSubscriptionDeleted(
const email = orgUserRes.user.email;
if (email) {
moveEmailToAudience(email, AudienceIds.Churned);
// TODO: update user in Sendy
}
}
} else if (type === "license") {

View File

@@ -480,9 +480,9 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyClientAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifyClientAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret
@@ -490,9 +490,9 @@ authenticated.post(
authenticated.post(
"/re-key/:siteId/regenerate-site-secret",
verifySiteAccess, // this is first to set the org id
verifyValidLicense,
verifyValidSubscription(tierMatrix.rotateCredentials),
verifySiteAccess, // this is first to set the org id
verifyLimits,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret
@@ -515,6 +515,6 @@ authenticated.post(
verifyOrgAccess,
verifyLimits,
verifyUserHasAction(ActionsEnum.signSshKey),
logActionAudit(ActionsEnum.signSshKey),
// logActionAudit(ActionsEnum.signSshKey), // it is handled inside of the function below so we can include more metadata
ssh.signSshKey
);

View File

@@ -15,6 +15,7 @@ import { verifySessionRemoteExitNodeMiddleware } from "#private/middlewares/veri
import { Router } from "express";
import {
db,
logsDb,
exitNodes,
Resource,
ResourcePassword,
@@ -81,6 +82,7 @@ import { verifyResourceAccessToken } from "@server/auth/verifyResourceAccessToke
import semver from "semver";
import { maxmindAsnLookup } from "@server/db/maxmindAsn";
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
import { sanitizeString } from "@server/lib/sanitize";
// Zod schemas for request validation
const getResourceByDomainParamsSchema = z.strictObject({
@@ -1859,24 +1861,24 @@ hybridRouter.post(
})
.map((logEntry) => ({
timestamp: logEntry.timestamp,
orgId: logEntry.orgId,
actorType: logEntry.actorType,
actor: logEntry.actor,
actorId: logEntry.actorId,
metadata: logEntry.metadata,
orgId: sanitizeString(logEntry.orgId),
actorType: sanitizeString(logEntry.actorType),
actor: sanitizeString(logEntry.actor),
actorId: sanitizeString(logEntry.actorId),
metadata: sanitizeString(logEntry.metadata),
action: logEntry.action,
resourceId: logEntry.resourceId,
reason: logEntry.reason,
location: logEntry.location,
location: sanitizeString(logEntry.location),
// userAgent: data.userAgent, // TODO: add this
// headers: data.body.headers,
// query: data.body.query,
originalRequestURL: logEntry.originalRequestURL,
scheme: logEntry.scheme,
host: logEntry.host,
path: logEntry.path,
method: logEntry.method,
ip: logEntry.ip,
originalRequestURL: sanitizeString(logEntry.originalRequestURL) ?? "",
scheme: sanitizeString(logEntry.scheme) ?? "",
host: sanitizeString(logEntry.host) ?? "",
path: sanitizeString(logEntry.path) ?? "",
method: sanitizeString(logEntry.method) ?? "",
ip: sanitizeString(logEntry.ip),
tls: logEntry.tls
}));
@@ -1884,7 +1886,7 @@ hybridRouter.post(
const batchSize = 100;
for (let i = 0; i < logEntries.length; i += batchSize) {
const batch = logEntries.slice(i, i + batchSize);
await db.insert(requestAuditLog).values(batch);
await logsDb.insert(requestAuditLog).values(batch);
}
return response(res, {

View File

@@ -52,7 +52,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/idp/oidc",
description: "Create an OIDC IdP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -35,7 +35,7 @@ registry.registerPath({
method: "delete",
path: "/org/{orgId}/idp/{idpId}",
description: "Delete IDP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema
},

View File

@@ -50,9 +50,9 @@ async function query(idpId: number, orgId: string) {
registry.registerPath({
method: "get",
path: "/org/:orgId/idp/:idpId",
path: "/org/{orgId}/idp/{idpId}",
description: "Get an IDP by its IDP ID for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema
},

View File

@@ -67,7 +67,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/idp",
description: "List all IDP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
query: querySchema,
params: paramsSchema

View File

@@ -59,7 +59,7 @@ registry.registerPath({
method: "post",
path: "/org/{orgId}/idp/{idpId}/oidc",
description: "Update an OIDC IdP for a specific organization.",
tags: [OpenAPITags.Idp, OpenAPITags.Org],
tags: [OpenAPITags.OrgIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -38,7 +38,7 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const newlyOfflineNodes = await db
const offlineNodes = await db
.update(exitNodes)
.set({ online: false })
.where(
@@ -53,32 +53,15 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
)
.returning();
// Update the sites to offline if they have not pinged either
const exitNodeIds = newlyOfflineNodes.map(
(node) => node.exitNodeId
);
const sitesOnNode = await db
.select()
.from(sites)
.where(
and(
eq(sites.online, true),
inArray(sites.exitNodeId, exitNodeIds)
)
if (offlineNodes.length > 0) {
logger.info(
`checkRemoteExitNodeOffline: Marked ${offlineNodes.length} remoteExitNode client(s) offline due to inactivity`
);
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
for (const site of sitesOnNode) {
if (!site.lastBandwidthUpdate) {
continue;
}
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
if (Date.now() - lastBandwidthUpdate.getTime() > 60 * 1000) {
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, site.siteId));
for (const offlineClient of offlineNodes) {
logger.debug(
`checkRemoteExitNodeOffline: Client ${offlineClient.exitNodeId} marked offline (lastPing: ${offlineClient.lastPing})`
);
}
}
} catch (error) {

View File

@@ -52,7 +52,7 @@ registry.registerPath({
method: "get",
path: "/maintenance/info",
description: "Get maintenance information for a resource by domain.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
query: z.object({
fullDomain: z.string()

View File

@@ -14,7 +14,9 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import {
actionAuditLog,
db,
logsDb,
newts,
roles,
roundTripMessageTracker,
@@ -29,12 +31,12 @@ import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { and, eq, inArray, or } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA";
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
import config from "@server/lib/config";
import { sendToClient } from "#private/routers/ws";
import { ActionsEnum } from "@server/auth/actions";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
@@ -64,6 +66,7 @@ export type SignSshKeyResponse = {
sshUsername: string;
sshHost: string;
resourceId: number;
siteId: number;
keyId: string;
validPrincipals: string[];
validAfter: string;
@@ -185,7 +188,7 @@ export async function signSshKey(
} else if (req.user?.username) {
usernameToUse = req.user.username;
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "-");
if (!usernameToUse) {
return next(
createHttpError(
@@ -203,6 +206,9 @@ export async function signSshKey(
);
}
// prefix with p-
usernameToUse = `p-${usernameToUse}`;
// check if we have a existing user in this org with the same
const [existingUserWithSameName] = await db
.select()
@@ -248,6 +254,16 @@ export async function signSshKey(
);
}
}
await db
.update(userOrgs)
.set({ pamUsername: usernameToUse })
.where(
and(
eq(userOrgs.orgId, orgId),
eq(userOrgs.userId, userId)
)
);
} else {
usernameToUse = userOrg.pamUsername;
}
@@ -319,7 +335,16 @@ export async function signSshKey(
);
}
// Check if the user has access to the resource (any of their roles)
if (resource.mode == "cidr") {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"SSHing is not supported for CIDR resources"
)
);
}
// Check if the user has access to the resource
const hasAccess = await canUserAccessSiteResource({
userId: userId,
resourceId: resource.siteResourceId,
@@ -444,6 +469,20 @@ export async function signSshKey(
sshHost = resource.destination;
}
await logsDb.insert(actionAuditLog).values({
timestamp: Math.floor(Date.now() / 1000),
orgId: orgId,
actorType: "user",
actor: req.user?.username ?? "",
actorId: req.user?.userId ?? "",
action: ActionsEnum.signSshKey,
metadata: JSON.stringify({
resourceId: resource.siteResourceId,
resource: resource.name,
siteId: resource.siteId,
})
});
return response<SignSshKeyResponse>(res, {
data: {
certificate: cert.certificate,
@@ -451,6 +490,7 @@ export async function signSshKey(
sshUsername: usernameToUse,
sshHost: sshHost,
resourceId: resource.siteResourceId,
siteId: resource.siteId,
keyId: cert.keyId,
validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(),

View File

@@ -17,10 +17,13 @@ import {
startRemoteExitNodeOfflineChecker
} from "#private/routers/remoteExitNode";
import { MessageHandler } from "@server/routers/ws";
import { build } from "@server/build";
export const messageHandlers: Record<string, MessageHandler> = {
"remoteExitNode/register": handleRemoteExitNodeRegisterMessage,
"remoteExitNode/ping": handleRemoteExitNodePingMessage
};
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
if (build != "saas") {
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
}

View File

@@ -12,6 +12,7 @@
*/
import { Router, Request, Response } from "express";
import zlib from "zlib";
import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net";
@@ -24,7 +25,8 @@ import {
OlmSession,
RemoteExitNode,
RemoteExitNodeSession,
remoteExitNodes
remoteExitNodes,
sites
} from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
@@ -57,11 +59,13 @@ const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection
const processMessage = async (
ws: AuthenticatedWebSocket,
data: Buffer,
isBinary: boolean,
clientId: string,
clientType: ClientType
): Promise<void> => {
try {
const message: WSMessage = JSON.parse(data.toString());
const messageBuffer = isBinary ? zlib.gunzipSync(data) : data;
const message: WSMessage = JSON.parse(messageBuffer.toString());
// logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
@@ -76,7 +80,7 @@ const processMessage = async (
clientId,
message.type, // Pass message type for granular limiting
100, // max requests per window
20, // max requests per message type per window
100, // max requests per message type per window
60 * 1000 // window in milliseconds
);
if (rateLimitResult.isLimited) {
@@ -163,8 +167,16 @@ const processPendingMessages = async (
);
const jobs = [];
for (const messageData of ws.pendingMessages) {
jobs.push(processMessage(ws, messageData, clientId, clientType));
for (const pending of ws.pendingMessages) {
jobs.push(
processMessage(
ws,
pending.data,
pending.isBinary,
clientId,
clientType
)
);
}
await Promise.all(jobs);
@@ -185,6 +197,12 @@ const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
// Config version tracking map (local to this node, resets on server restart)
const clientConfigVersions: Map<string, number> = new Map();
// Tracks the last Unix timestamp (seconds) at which a ping was flushed to the
// DB for a given siteId. Resets on server restart which is fine the first
// ping after startup will always write, re-establishing the online state.
const lastPingDbWrite: Map<number, number> = new Map();
const PING_DB_WRITE_INTERVAL = 45; // seconds
// Recovery tracking
let isRedisRecoveryInProgress = false;
@@ -325,7 +343,9 @@ const addClient = async (
// Check Redis first if enabled
if (redisManager.isRedisEnabled()) {
try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10);
// Sync to local cache
@@ -337,7 +357,10 @@ const addClient = async (
} else {
// Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
await redisManager.set(
getConfigVersionKey(clientId),
configVersion.toString()
);
}
} catch (error) {
logger.error("Failed to get/set config version in Redis:", error);
@@ -432,7 +455,9 @@ const removeClient = async (
};
// Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
const getClientConfigVersion = async (
clientId: string
): Promise<number | undefined> => {
// Try Redis first if available
if (redisManager.isRedisEnabled()) {
try {
@@ -502,11 +527,26 @@ const sendToClientLocal = async (
};
const messageString = JSON.stringify(messageWithVersion);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
if (options.compress) {
logger.debug(
`Message size before compression: ${messageString.length} bytes`
);
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
logger.debug(
`Message size after compression: ${compressed.length} bytes`
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(messageString);
}
});
}
return true;
};
@@ -532,11 +572,22 @@ const broadcastToAllExceptLocal = async (
configVersion
};
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
if (options.compress) {
const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion));
}
});
}
}
}
};
@@ -762,7 +813,7 @@ const setupConnection = async (
}
// Set up message handler FIRST to prevent race condition
ws.on("message", async (data) => {
ws.on("message", async (data, isBinary) => {
if (!ws.isFullyConnected) {
// Queue message for later processing with limits
ws.pendingMessages = ws.pendingMessages || [];
@@ -777,11 +828,17 @@ const setupConnection = async (
logger.debug(
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
);
ws.pendingMessages.push(data as Buffer);
ws.pendingMessages.push({ data: data as Buffer, isBinary });
return;
}
await processMessage(ws, data as Buffer, clientId, clientType);
await processMessage(
ws,
data as Buffer,
isBinary,
clientId,
clientType
);
});
// Set up other event handlers before async operations
@@ -796,6 +853,35 @@ const setupConnection = async (
);
});
// Handle WebSocket protocol-level pings from older newt clients that do
// not send application-level "newt/ping" messages. Update the site's
// online state and lastPing timestamp so the offline checker treats them
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
if (!newtClient.siteId) return;
const now = Math.floor(Date.now() / 1000);
const lastWrite = lastPingDbWrite.get(newtClient.siteId) ?? 0;
if (now - lastWrite < PING_DB_WRITE_INTERVAL) return;
lastPingDbWrite.set(newtClient.siteId, now);
try {
await db
.update(sites)
.set({
online: true,
lastPing: now
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
});
}
ws.on("error", (error: Error) => {
logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,