From 7d8797840a1d9cc8d455fe858ff97d4f3b206e15 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 17:01:34 -0700 Subject: [PATCH 01/24] Add connection log --- server/cleanup.ts | 4 +- server/db/pg/schema/privateSchema.ts | 38 ++- server/db/sqlite/schema/privateSchema.ts | 36 ++- server/private/cleanup.ts | 4 +- server/private/routers/ws/messageHandlers.ts | 4 +- .../newt/handleConnectionLogMessage.ts | 302 ++++++++++++++++++ server/routers/newt/index.ts | 1 + 7 files changed, 384 insertions(+), 5 deletions(-) create mode 100644 server/routers/newt/handleConnectionLogMessage.ts diff --git a/server/cleanup.ts b/server/cleanup.ts index 137654827..7366bb876 100644 --- a/server/cleanup.ts +++ b/server/cleanup.ts @@ -1,9 +1,11 @@ import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage"; +import { flushConnectionLogToDb } from "@server/routers/newt/handleConnectionLogMessage"; import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth"; import { cleanup as wsCleanup } from "#dynamic/routers/ws"; async function cleanup() { await flushBandwidthToDb(); + await flushConnectionLogToDb(); await flushSiteBandwidthToDb(); await wsCleanup(); @@ -14,4 +16,4 @@ export async function initCleanup() { // Handle process termination process.on("SIGTERM", () => cleanup()); process.on("SIGINT", () => cleanup()); -} \ No newline at end of file +} diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index c9d7cc907..8fed6462a 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -17,7 +17,9 @@ import { users, exitNodes, sessions, - clients + clients, + siteResources, + sites } from "./schema"; export const certificates = pgTable("certificates", { @@ -302,6 +304,39 @@ export const accessAuditLog = pgTable( ] ); +export const connectionAuditLog = pgTable( + "connectionAuditLog", + { + id: serial("id").primaryKey(), + sessionId: text("sessionId").notNull(), + siteResourceId: integer("siteResourceId").references( + () => siteResources.siteResourceId, + { onDelete: "cascade" } + ), + orgId: text("orgId").references(() => orgs.orgId, { + onDelete: "cascade" + }), + siteId: integer("siteId").references(() => sites.siteId, { + onDelete: "cascade" + }), + sourceAddr: text("sourceAddr").notNull(), + destAddr: text("destAddr").notNull(), + protocol: text("protocol").notNull(), + startedAt: integer("startedAt").notNull(), + endedAt: integer("endedAt"), + bytesTx: integer("bytesTx"), + bytesRx: integer("bytesRx") + }, + (table) => [ + index("idx_connectionAuditLog_startedAt").on(table.startedAt), + index("idx_connectionAuditLog_org_startedAt").on( + table.orgId, + table.startedAt + ), + index("idx_connectionAuditLog_siteResourceId").on(table.siteResourceId) + ] +); + export const approvals = pgTable("approvals", { approvalId: serial("approvalId").primaryKey(), timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds @@ -357,3 +392,4 @@ export type LoginPage = InferSelectModel<typeof loginPage>; export type LoginPageBranding = InferSelectModel<typeof loginPageBranding>; export type ActionAuditLog = InferSelectModel<typeof actionAuditLog>; export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>; +export type ConnectionAuditLog = InferSelectModel<typeof connectionAuditLog>; diff --git a/server/db/sqlite/schema/privateSchema.ts b/server/db/sqlite/schema/privateSchema.ts index 8baeb5220..ecc386ea6 100644 --- a/server/db/sqlite/schema/privateSchema.ts +++ b/server/db/sqlite/schema/privateSchema.ts @@ -6,7 +6,7 @@ import { sqliteTable, text } from "drizzle-orm/sqlite-core"; -import { clients, 
domains, exitNodes, orgs, sessions, users } from "./schema"; +import { clients, domains, exitNodes, orgs, sessions, siteResources, sites, users } from "./schema"; export const certificates = sqliteTable("certificates", { certId: integer("certId").primaryKey({ autoIncrement: true }), @@ -294,6 +294,39 @@ export const accessAuditLog = sqliteTable( ] ); +export const connectionAuditLog = sqliteTable( + "connectionAuditLog", + { + id: integer("id").primaryKey({ autoIncrement: true }), + sessionId: text("sessionId").notNull(), + siteResourceId: integer("siteResourceId").references( + () => siteResources.siteResourceId, + { onDelete: "cascade" } + ), + orgId: text("orgId").references(() => orgs.orgId, { + onDelete: "cascade" + }), + siteId: integer("siteId").references(() => sites.siteId, { + onDelete: "cascade" + }), + sourceAddr: text("sourceAddr").notNull(), + destAddr: text("destAddr").notNull(), + protocol: text("protocol").notNull(), + startedAt: integer("startedAt").notNull(), + endedAt: integer("endedAt"), + bytesTx: integer("bytesTx"), + bytesRx: integer("bytesRx") + }, + (table) => [ + index("idx_connectionAuditLog_startedAt").on(table.startedAt), + index("idx_connectionAuditLog_org_startedAt").on( + table.orgId, + table.startedAt + ), + index("idx_connectionAuditLog_siteResourceId").on(table.siteResourceId) + ] +); + export const approvals = sqliteTable("approvals", { approvalId: integer("approvalId").primaryKey({ autoIncrement: true }), timestamp: integer("timestamp").notNull(), // this is EPOCH time in seconds @@ -348,3 +381,4 @@ export type LoginPage = InferSelectModel<typeof loginPage>; export type LoginPageBranding = InferSelectModel<typeof loginPageBranding>; export type ActionAuditLog = InferSelectModel<typeof actionAuditLog>; export type AccessAuditLog = InferSelectModel<typeof accessAuditLog>; +export type ConnectionAuditLog = InferSelectModel<typeof connectionAuditLog>; diff --git a/server/private/cleanup.ts b/server/private/cleanup.ts index 0bd9822dd..933c943ed 100644 --- a/server/private/cleanup.ts +++ b/server/private/cleanup.ts @@ -14,10 +14,12 @@ import { rateLimitService } from "#private/lib/rateLimit"; import { cleanup as wsCleanup } from "#private/routers/ws"; import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage"; +import { flushConnectionLogToDb } from "@server/routers/newt/handleConnectionLogMessage"; import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth"; async function cleanup() { await flushBandwidthToDb(); + await flushConnectionLogToDb(); await flushSiteBandwidthToDb(); await rateLimitService.cleanup(); await wsCleanup(); @@ -29,4 +31,4 @@ export async function initCleanup() { // Handle process termination process.on("SIGTERM", () => cleanup()); process.on("SIGINT", () => cleanup()); -} \ No newline at end of file +} diff --git a/server/private/routers/ws/messageHandlers.ts b/server/private/routers/ws/messageHandlers.ts index d388ce40a..9e4622e2d 100644 --- a/server/private/routers/ws/messageHandlers.ts +++ b/server/private/routers/ws/messageHandlers.ts @@ -18,10 +18,12 @@ import { } from "#private/routers/remoteExitNode"; import { MessageHandler } from "@server/routers/ws"; import { build } from "@server/build"; +import { handleConnectionLogMessage } from "@server/routers/newt"; export const messageHandlers: Record<string, MessageHandler> = { "remoteExitNode/register": handleRemoteExitNodeRegisterMessage, - "remoteExitNode/ping": handleRemoteExitNodePingMessage + "remoteExitNode/ping": handleRemoteExitNodePingMessage, + "newt/access-log": handleConnectionLogMessage, }; if (build != "saas") { diff --git 
a/server/routers/newt/handleConnectionLogMessage.ts b/server/routers/newt/handleConnectionLogMessage.ts new file mode 100644 index 000000000..458470af7 --- /dev/null +++ b/server/routers/newt/handleConnectionLogMessage.ts @@ -0,0 +1,302 @@ +import { db } from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { connectionAuditLog, sites, Newt } from "@server/db"; +import { eq } from "drizzle-orm"; +import logger from "@server/logger"; +import { inflate } from "zlib"; +import { promisify } from "util"; + +const zlibInflate = promisify(inflate); + +// Retry configuration for deadlock handling +const MAX_RETRIES = 3; +const BASE_DELAY_MS = 50; + +// How often to flush accumulated connection log data to the database +const FLUSH_INTERVAL_MS = 30_000; // 30 seconds + +// Maximum number of records to buffer before forcing a flush +const MAX_BUFFERED_RECORDS = 500; + +// Maximum number of records to insert in a single batch +const INSERT_BATCH_SIZE = 100; + +interface ConnectionSessionData { + sessionId: string; + resourceId: number; + sourceAddr: string; + destAddr: string; + protocol: string; + startedAt: string; // ISO 8601 timestamp + endedAt?: string; // ISO 8601 timestamp + bytesTx?: number; + bytesRx?: number; +} + +interface ConnectionLogRecord { + sessionId: string; + siteResourceId: number; + orgId: string; + siteId: number; + sourceAddr: string; + destAddr: string; + protocol: string; + startedAt: number; // epoch seconds + endedAt: number | null; + bytesTx: number | null; + bytesRx: number | null; +} + +// In-memory buffer of records waiting to be flushed +let buffer: ConnectionLogRecord[] = []; + +/** + * Check if an error is a deadlock error + */ +function isDeadlockError(error: any): boolean { + return ( + error?.code === "40P01" || + error?.cause?.code === "40P01" || + (error?.message && error.message.includes("deadlock")) + ); +} + +/** + * Execute a function with retry logic for deadlock handling + */ +async function withDeadlockRetry<T>( + operation: () => Promise<T>, + context: string +): Promise<T> { + let attempt = 0; + while (true) { + try { + return await operation(); + } catch (error: any) { + if (isDeadlockError(error) && attempt < MAX_RETRIES) { + attempt++; + const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS; + const jitter = Math.random() * baseDelay; + const delay = baseDelay + jitter; + logger.warn( + `Deadlock detected in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms` + ); + await new Promise((resolve) => setTimeout(resolve, delay)); + continue; + } + throw error; + } + } +} + +/** + * Decompress a base64-encoded zlib-compressed string into parsed JSON. + */ +async function decompressConnectionLog( + compressed: string +): Promise<ConnectionSessionData[]> { + const compressedBuffer = Buffer.from(compressed, "base64"); + const decompressed = await zlibInflate(compressedBuffer); + const jsonString = decompressed.toString("utf-8"); + const parsed = JSON.parse(jsonString); + + if (!Array.isArray(parsed)) { + throw new Error("Decompressed connection log data is not an array"); + } + + return parsed; +} + +/** + * Convert an ISO 8601 timestamp string to epoch seconds. + * Returns null if the input is falsy. + */ +function toEpochSeconds(isoString: string | undefined | null): number | null { + if (!isoString) { + return null; + } + const ms = new Date(isoString).getTime(); + if (isNaN(ms)) { + return null; + } + return Math.floor(ms / 1000); +} + +/** + * Flush all buffered connection log records to the database. 
+ * + * Swaps out the buffer before writing so that any records added during the + * flush are captured in the new buffer rather than being lost. Entries that + * fail to write are re-queued back into the buffer so they will be retried + * on the next flush. + * + * This function is exported so that the application's graceful-shutdown + * cleanup handler can call it before the process exits. + */ +export async function flushConnectionLogToDb(): Promise<void> { + if (buffer.length === 0) { + return; + } + + // Atomically swap out the buffer so new data keeps flowing in + const snapshot = buffer; + buffer = []; + + logger.debug( + `Flushing ${snapshot.length} connection log record(s) to the database` + ); + + // Insert in batches to avoid overly large SQL statements + for (let i = 0; i < snapshot.length; i += INSERT_BATCH_SIZE) { + const batch = snapshot.slice(i, i + INSERT_BATCH_SIZE); + + try { + await withDeadlockRetry(async () => { + await db.insert(connectionAuditLog).values(batch); + }, `flush connection log batch (${batch.length} records)`); + } catch (error) { + logger.error( + `Failed to flush connection log batch of ${batch.length} records:`, + error + ); + + // Re-queue the failed batch so it is retried on the next flush + buffer = [...batch, ...buffer]; + + // Cap buffer to prevent unbounded growth if DB is unreachable + if (buffer.length > MAX_BUFFERED_RECORDS * 5) { + const dropped = buffer.length - MAX_BUFFERED_RECORDS * 5; + buffer = buffer.slice(0, MAX_BUFFERED_RECORDS * 5); + logger.warn( + `Connection log buffer overflow, dropped ${dropped} oldest records` + ); + } + + // Stop trying further batches from this snapshot — they'll be + // picked up by the next flush via the re-queued records above + const remaining = snapshot.slice(i + INSERT_BATCH_SIZE); + if (remaining.length > 0) { + buffer = [...remaining, ...buffer]; + } + break; + } + } +} + +const flushTimer = setInterval(async () => { + try { + await flushConnectionLogToDb(); + } catch (error) { + logger.error( + "Unexpected error during periodic connection log flush:", + error + ); + } +}, FLUSH_INTERVAL_MS); + +// Calling unref() means this timer will not keep the Node.js event loop alive +// on its own — the process can still exit normally when there is no other work +// left. The graceful-shutdown path will call flushConnectionLogToDb() explicitly +// before process.exit(), so no data is lost. 
+flushTimer.unref(); + +export const handleConnectionLogMessage: MessageHandler = async (context) => { + const { message, client } = context; + const newt = client as Newt; + + if (!newt) { + logger.warn("Connection log received but no newt client in context"); + return; + } + + if (!newt.siteId) { + logger.warn("Connection log received but newt has no siteId"); + return; + } + + if (!message.data?.compressed) { + logger.warn("Connection log message missing compressed data"); + return; + } + + // Look up the org for this site + const [site] = await db + .select({ orgId: sites.orgId }) + .from(sites) + .where(eq(sites.siteId, newt.siteId)); + + if (!site) { + logger.warn( + `Connection log received but site ${newt.siteId} not found in database` + ); + return; + } + + const orgId = site.orgId; + + let sessions: ConnectionSessionData[]; + try { + sessions = await decompressConnectionLog(message.data.compressed); + } catch (error) { + logger.error("Failed to decompress connection log data:", error); + return; + } + + if (sessions.length === 0) { + return; + } + + // Convert to DB records and add to the buffer + for (const session of sessions) { + // Validate required fields + if ( + !session.sessionId || + !session.resourceId || + !session.sourceAddr || + !session.destAddr || + !session.protocol + ) { + logger.debug( + `Skipping connection log session with missing required fields: ${JSON.stringify(session)}` + ); + continue; + } + + const startedAt = toEpochSeconds(session.startedAt); + if (startedAt === null) { + logger.debug( + `Skipping connection log session with invalid startedAt: ${session.startedAt}` + ); + continue; + } + + buffer.push({ + sessionId: session.sessionId, + siteResourceId: session.resourceId, + orgId, + siteId: newt.siteId, + sourceAddr: session.sourceAddr, + destAddr: session.destAddr, + protocol: session.protocol, + startedAt, + endedAt: toEpochSeconds(session.endedAt), + bytesTx: session.bytesTx ?? null, + bytesRx: session.bytesRx ?? 
null + }); + } + + logger.debug( + `Buffered ${sessions.length} connection log session(s) from newt ${newt.newtId} (site ${newt.siteId})` + ); + + // If the buffer has grown large enough, trigger an immediate flush + if (buffer.length >= MAX_BUFFERED_RECORDS) { + // Fire and forget — errors are handled inside flushConnectionLogToDb + flushConnectionLogToDb().catch((error) => { + logger.error( + "Unexpected error during size-triggered connection log flush:", + error + ); + }); + } +}; diff --git a/server/routers/newt/index.ts b/server/routers/newt/index.ts index f31cd753b..63d1e1068 100644 --- a/server/routers/newt/index.ts +++ b/server/routers/newt/index.ts @@ -8,3 +8,4 @@ export * from "./handleNewtPingRequestMessage"; export * from "./handleApplyBlueprintMessage"; export * from "./handleNewtPingMessage"; export * from "./handleNewtDisconnectingMessage"; +export * from "./handleConnectionLogMessage"; From 0d4edcd1c79219f65f223cbdce179c3cf0cfcf1b Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 17:23:51 -0700 Subject: [PATCH 02/24] make private --- server/cleanup.ts | 2 +- server/db/pg/schema/schema.ts | 3 + server/db/sqlite/schema/schema.ts | 3 + server/lib/cleanupLogs.ts | 18 +- server/private/cleanup.ts | 2 +- .../newt/handleConnectionLogMessage.ts | 324 ++++++++++++++++++ server/private/routers/newt/index.ts | 1 + server/private/routers/ws/messageHandlers.ts | 2 +- .../newt/handleConnectionLogMessage.ts | 299 +--------------- 9 files changed, 354 insertions(+), 300 deletions(-) create mode 100644 server/private/routers/newt/handleConnectionLogMessage.ts create mode 100644 server/private/routers/newt/index.ts diff --git a/server/cleanup.ts b/server/cleanup.ts index 7366bb876..81cc31692 100644 --- a/server/cleanup.ts +++ b/server/cleanup.ts @@ -1,5 +1,5 @@ import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage"; -import { flushConnectionLogToDb } from "@server/routers/newt/handleConnectionLogMessage"; +import { flushConnectionLogToDb } from "#dynamic/routers/newt"; import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth"; import { cleanup as wsCleanup } from "#dynamic/routers/ws"; diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index b93c21fd6..de423945d 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -55,6 +55,9 @@ export const orgs = pgTable("orgs", { settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year .notNull() .default(0), + settingsLogRetentionDaysConnection: integer("settingsLogRetentionDaysConnection") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year + .notNull() + .default(0), sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format) sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format) isBillingOrg: boolean("isBillingOrg"), diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 188caac2b..ba02cfb76 100644 --- a/server/db/sqlite/schema/schema.ts +++ b/server/db/sqlite/schema/schema.ts @@ -47,6 +47,9 @@ export const orgs = sqliteTable("orgs", { settingsLogRetentionDaysAction: integer("settingsLogRetentionDaysAction") // where 0 = dont keep logs and -1 = keep forever and 9001 = end of the following year .notNull() .default(0), + settingsLogRetentionDaysConnection: integer("settingsLogRetentionDaysConnection") // where 0 = dont 
keep logs and -1 = keep forever and 9001 = end of the following year + .notNull() + .default(0), sshCaPrivateKey: text("sshCaPrivateKey"), // Encrypted SSH CA private key (PEM format) sshCaPublicKey: text("sshCaPublicKey"), // SSH CA public key (OpenSSH format) isBillingOrg: integer("isBillingOrg", { mode: "boolean" }), diff --git a/server/lib/cleanupLogs.ts b/server/lib/cleanupLogs.ts index 8eb4ca77f..f5b6d8b2f 100644 --- a/server/lib/cleanupLogs.ts +++ b/server/lib/cleanupLogs.ts @@ -2,6 +2,7 @@ import { db, orgs } from "@server/db"; import { cleanUpOldLogs as cleanUpOldAccessLogs } from "#dynamic/lib/logAccessAudit"; import { cleanUpOldLogs as cleanUpOldActionLogs } from "#dynamic/middlewares/logActionAudit"; import { cleanUpOldLogs as cleanUpOldRequestLogs } from "@server/routers/badger/logRequestAudit"; +import { cleanUpOldLogs as cleanUpOldConnectionLogs } from "#dynamic/routers/newt"; import { gt, or } from "drizzle-orm"; import { cleanUpOldFingerprintSnapshots } from "@server/routers/olm/fingerprintingUtils"; import { build } from "@server/build"; @@ -20,14 +21,17 @@ export function initLogCleanupInterval() { settingsLogRetentionDaysAccess: orgs.settingsLogRetentionDaysAccess, settingsLogRetentionDaysRequest: - orgs.settingsLogRetentionDaysRequest + orgs.settingsLogRetentionDaysRequest, + settingsLogRetentionDaysConnection: + orgs.settingsLogRetentionDaysConnection }) .from(orgs) .where( or( gt(orgs.settingsLogRetentionDaysAction, 0), gt(orgs.settingsLogRetentionDaysAccess, 0), - gt(orgs.settingsLogRetentionDaysRequest, 0) + gt(orgs.settingsLogRetentionDaysRequest, 0), + gt(orgs.settingsLogRetentionDaysConnection, 0) ) ); @@ -37,7 +41,8 @@ export function initLogCleanupInterval() { orgId, settingsLogRetentionDaysAction, settingsLogRetentionDaysAccess, - settingsLogRetentionDaysRequest + settingsLogRetentionDaysRequest, + settingsLogRetentionDaysConnection } = org; if (settingsLogRetentionDaysAction > 0) { @@ -60,6 +65,13 @@ export function initLogCleanupInterval() { settingsLogRetentionDaysRequest ); } + + if (settingsLogRetentionDaysConnection > 0) { + await cleanUpOldConnectionLogs( + orgId, + settingsLogRetentionDaysConnection + ); + } } await cleanUpOldFingerprintSnapshots(365); diff --git a/server/private/cleanup.ts b/server/private/cleanup.ts index 933c943ed..4b12f1b3c 100644 --- a/server/private/cleanup.ts +++ b/server/private/cleanup.ts @@ -14,7 +14,7 @@ import { rateLimitService } from "#private/lib/rateLimit"; import { cleanup as wsCleanup } from "#private/routers/ws"; import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage"; -import { flushConnectionLogToDb } from "@server/routers/newt/handleConnectionLogMessage"; +import { flushConnectionLogToDb } from "#dynamic/routers/newt"; import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth"; async function cleanup() { diff --git a/server/private/routers/newt/handleConnectionLogMessage.ts b/server/private/routers/newt/handleConnectionLogMessage.ts new file mode 100644 index 000000000..7549ab37f --- /dev/null +++ b/server/private/routers/newt/handleConnectionLogMessage.ts @@ -0,0 +1,324 @@ +import { db, logsDb } from "@server/db"; +import { MessageHandler } from "@server/routers/ws"; +import { connectionAuditLog, sites, Newt } from "@server/db"; +import { and, eq, lt } from "drizzle-orm"; +import logger from "@server/logger"; +import { inflate } from "zlib"; +import { promisify } from "util"; +import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs"; + +const 
zlibInflate = promisify(inflate); + +// Retry configuration for deadlock handling +const MAX_RETRIES = 3; +const BASE_DELAY_MS = 50; + +// How often to flush accumulated connection log data to the database +const FLUSH_INTERVAL_MS = 30_000; // 30 seconds + +// Maximum number of records to buffer before forcing a flush +const MAX_BUFFERED_RECORDS = 500; + +// Maximum number of records to insert in a single batch +const INSERT_BATCH_SIZE = 100; + +interface ConnectionSessionData { + sessionId: string; + resourceId: number; + sourceAddr: string; + destAddr: string; + protocol: string; + startedAt: string; // ISO 8601 timestamp + endedAt?: string; // ISO 8601 timestamp + bytesTx?: number; + bytesRx?: number; +} + +interface ConnectionLogRecord { + sessionId: string; + siteResourceId: number; + orgId: string; + siteId: number; + sourceAddr: string; + destAddr: string; + protocol: string; + startedAt: number; // epoch seconds + endedAt: number | null; + bytesTx: number | null; + bytesRx: number | null; +} + +// In-memory buffer of records waiting to be flushed +let buffer: ConnectionLogRecord[] = []; + +/** + * Check if an error is a deadlock error + */ +function isDeadlockError(error: any): boolean { + return ( + error?.code === "40P01" || + error?.cause?.code === "40P01" || + (error?.message && error.message.includes("deadlock")) + ); +} + +/** + * Execute a function with retry logic for deadlock handling + */ +async function withDeadlockRetry<T>( + operation: () => Promise<T>, + context: string +): Promise<T> { + let attempt = 0; + while (true) { + try { + return await operation(); + } catch (error: any) { + if (isDeadlockError(error) && attempt < MAX_RETRIES) { + attempt++; + const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS; + const jitter = Math.random() * baseDelay; + const delay = baseDelay + jitter; + logger.warn( + `Deadlock detected in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms` + ); + await new Promise((resolve) => setTimeout(resolve, delay)); + continue; + } + throw error; + } + } +} + +/** + * Decompress a base64-encoded zlib-compressed string into parsed JSON. + */ +async function decompressConnectionLog( + compressed: string +): Promise<ConnectionSessionData[]> { + const compressedBuffer = Buffer.from(compressed, "base64"); + const decompressed = await zlibInflate(compressedBuffer); + const jsonString = decompressed.toString("utf-8"); + const parsed = JSON.parse(jsonString); + + if (!Array.isArray(parsed)) { + throw new Error("Decompressed connection log data is not an array"); + } + + return parsed; +} + +/** + * Convert an ISO 8601 timestamp string to epoch seconds. + * Returns null if the input is falsy. + */ +function toEpochSeconds(isoString: string | undefined | null): number | null { + if (!isoString) { + return null; + } + const ms = new Date(isoString).getTime(); + if (isNaN(ms)) { + return null; + } + return Math.floor(ms / 1000); +} + +/** + * Flush all buffered connection log records to the database. + * + * Swaps out the buffer before writing so that any records added during the + * flush are captured in the new buffer rather than being lost. Entries that + * fail to write are re-queued back into the buffer so they will be retried + * on the next flush. + * + * This function is exported so that the application's graceful-shutdown + * cleanup handler can call it before the process exits. 
+ */ +export async function flushConnectionLogToDb(): Promise<void> { + if (buffer.length === 0) { + return; + } + + // Atomically swap out the buffer so new data keeps flowing in + const snapshot = buffer; + buffer = []; + + logger.debug( + `Flushing ${snapshot.length} connection log record(s) to the database` + ); + + // Insert in batches to avoid overly large SQL statements + for (let i = 0; i < snapshot.length; i += INSERT_BATCH_SIZE) { + const batch = snapshot.slice(i, i + INSERT_BATCH_SIZE); + + try { + await withDeadlockRetry(async () => { + await logsDb.insert(connectionAuditLog).values(batch); + }, `flush connection log batch (${batch.length} records)`); + } catch (error) { + logger.error( + `Failed to flush connection log batch of ${batch.length} records:`, + error + ); + + // Re-queue the failed batch so it is retried on the next flush + buffer = [...batch, ...buffer]; + + // Cap buffer to prevent unbounded growth if DB is unreachable + if (buffer.length > MAX_BUFFERED_RECORDS * 5) { + const dropped = buffer.length - MAX_BUFFERED_RECORDS * 5; + buffer = buffer.slice(0, MAX_BUFFERED_RECORDS * 5); + logger.warn( + `Connection log buffer overflow, dropped ${dropped} oldest records` + ); + } + + // Stop trying further batches from this snapshot — they'll be + // picked up by the next flush via the re-queued records above + const remaining = snapshot.slice(i + INSERT_BATCH_SIZE); + if (remaining.length > 0) { + buffer = [...remaining, ...buffer]; + } + break; + } + } +} + +const flushTimer = setInterval(async () => { + try { + await flushConnectionLogToDb(); + } catch (error) { + logger.error( + "Unexpected error during periodic connection log flush:", + error + ); + } +}, FLUSH_INTERVAL_MS); + +// Calling unref() means this timer will not keep the Node.js event loop alive +// on its own — the process can still exit normally when there is no other work +// left. The graceful-shutdown path will call flushConnectionLogToDb() explicitly +// before process.exit(), so no data is lost. 
+flushTimer.unref(); + +export async function cleanUpOldLogs(orgId: string, retentionDays: number) { + const cutoffTimestamp = calculateCutoffTimestamp(retentionDays); + + try { + await logsDb + .delete(connectionAuditLog) + .where( + and( + lt(connectionAuditLog.startedAt, cutoffTimestamp), + eq(connectionAuditLog.orgId, orgId) + ) + ); + + // logger.debug( + // `Cleaned up connection audit logs older than ${retentionDays} days` + // ); + } catch (error) { + logger.error("Error cleaning up old connection audit logs:", error); + } +} + +export const handleConnectionLogMessage: MessageHandler = async (context) => { + const { message, client } = context; + const newt = client as Newt; + + if (!newt) { + logger.warn("Connection log received but no newt client in context"); + return; + } + + if (!newt.siteId) { + logger.warn("Connection log received but newt has no siteId"); + return; + } + + if (!message.data?.compressed) { + logger.warn("Connection log message missing compressed data"); + return; + } + + // Look up the org for this site + const [site] = await db + .select({ orgId: sites.orgId }) + .from(sites) + .where(eq(sites.siteId, newt.siteId)); + + if (!site) { + logger.warn( + `Connection log received but site ${newt.siteId} not found in database` + ); + return; + } + + const orgId = site.orgId; + + let sessions: ConnectionSessionData[]; + try { + sessions = await decompressConnectionLog(message.data.compressed); + } catch (error) { + logger.error("Failed to decompress connection log data:", error); + return; + } + + if (sessions.length === 0) { + return; + } + + // Convert to DB records and add to the buffer + for (const session of sessions) { + // Validate required fields + if ( + !session.sessionId || + !session.resourceId || + !session.sourceAddr || + !session.destAddr || + !session.protocol + ) { + logger.debug( + `Skipping connection log session with missing required fields: ${JSON.stringify(session)}` + ); + continue; + } + + const startedAt = toEpochSeconds(session.startedAt); + if (startedAt === null) { + logger.debug( + `Skipping connection log session with invalid startedAt: ${session.startedAt}` + ); + continue; + } + + buffer.push({ + sessionId: session.sessionId, + siteResourceId: session.resourceId, + orgId, + siteId: newt.siteId, + sourceAddr: session.sourceAddr, + destAddr: session.destAddr, + protocol: session.protocol, + startedAt, + endedAt: toEpochSeconds(session.endedAt), + bytesTx: session.bytesTx ?? null, + bytesRx: session.bytesRx ?? 
null + }); + } + + logger.debug( + `Buffered ${sessions.length} connection log session(s) from newt ${newt.newtId} (site ${newt.siteId})` + ); + + // If the buffer has grown large enough, trigger an immediate flush + if (buffer.length >= MAX_BUFFERED_RECORDS) { + // Fire and forget — errors are handled inside flushConnectionLogToDb + flushConnectionLogToDb().catch((error) => { + logger.error( + "Unexpected error during size-triggered connection log flush:", + error + ); + }); + } +}; diff --git a/server/private/routers/newt/index.ts b/server/private/routers/newt/index.ts new file mode 100644 index 000000000..cc182cf7d --- /dev/null +++ b/server/private/routers/newt/index.ts @@ -0,0 +1 @@ +export * from "./handleConnectionLogMessage"; diff --git a/server/private/routers/ws/messageHandlers.ts b/server/private/routers/ws/messageHandlers.ts index 9e4622e2d..a3c9c5bdb 100644 --- a/server/private/routers/ws/messageHandlers.ts +++ b/server/private/routers/ws/messageHandlers.ts @@ -18,7 +18,7 @@ import { } from "#private/routers/remoteExitNode"; import { MessageHandler } from "@server/routers/ws"; import { build } from "@server/build"; -import { handleConnectionLogMessage } from "@server/routers/newt"; +import { handleConnectionLogMessage } from "#dynamic/routers/newt"; export const messageHandlers: Record<string, MessageHandler> = { "remoteExitNode/register": handleRemoteExitNodeRegisterMessage, diff --git a/server/routers/newt/handleConnectionLogMessage.ts b/server/routers/newt/handleConnectionLogMessage.ts index 458470af7..ca1b129d2 100644 --- a/server/routers/newt/handleConnectionLogMessage.ts +++ b/server/routers/newt/handleConnectionLogMessage.ts @@ -1,302 +1,13 @@ -import { db } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { connectionAuditLog, sites, Newt } from "@server/db"; -import { eq } from "drizzle-orm"; -import logger from "@server/logger"; -import { inflate } from "zlib"; -import { promisify } from "util"; -const zlibInflate = promisify(inflate); - -// Retry configuration for deadlock handling -const MAX_RETRIES = 3; -const BASE_DELAY_MS = 50; - -// How often to flush accumulated connection log data to the database -const FLUSH_INTERVAL_MS = 30_000; // 30 seconds - -// Maximum number of records to buffer before forcing a flush -const MAX_BUFFERED_RECORDS = 500; - -// Maximum number of records to insert in a single batch -const INSERT_BATCH_SIZE = 100; - -interface ConnectionSessionData { - sessionId: string; - resourceId: number; - sourceAddr: string; - destAddr: string; - protocol: string; - startedAt: string; // ISO 8601 timestamp - endedAt?: string; // ISO 8601 timestamp - bytesTx?: number; - bytesRx?: number; -} - -interface ConnectionLogRecord { - sessionId: string; - siteResourceId: number; - orgId: string; - siteId: number; - sourceAddr: string; - destAddr: string; - protocol: string; - startedAt: number; // epoch seconds - endedAt: number | null; - bytesTx: number | null; - bytesRx: number | null; -} - -// In-memory buffer of records waiting to be flushed -let buffer: ConnectionLogRecord[] = []; - -/** - * Check if an error is a deadlock error - */ -function isDeadlockError(error: any): boolean { - return ( - error?.code === "40P01" || - error?.cause?.code === "40P01" || - (error?.message && error.message.includes("deadlock")) - ); -} - -/** - * Execute a function with retry logic for deadlock handling - */ -async function withDeadlockRetry<T>( - operation: () => Promise<T>, - context: string -): Promise<T> { - let attempt = 0; - while (true) { - try { - return await 
operation(); - } catch (error: any) { - if (isDeadlockError(error) && attempt < MAX_RETRIES) { - attempt++; - const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS; - const jitter = Math.random() * baseDelay; - const delay = baseDelay + jitter; - logger.warn( - `Deadlock detected in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms` - ); - await new Promise((resolve) => setTimeout(resolve, delay)); - continue; - } - throw error; - } - } -} - -/** - * Decompress a base64-encoded zlib-compressed string into parsed JSON. - */ -async function decompressConnectionLog( - compressed: string -): Promise<ConnectionSessionData[]> { - const compressedBuffer = Buffer.from(compressed, "base64"); - const decompressed = await zlibInflate(compressedBuffer); - const jsonString = decompressed.toString("utf-8"); - const parsed = JSON.parse(jsonString); - - if (!Array.isArray(parsed)) { - throw new Error("Decompressed connection log data is not an array"); - } - - return parsed; -} - -/** - * Convert an ISO 8601 timestamp string to epoch seconds. - * Returns null if the input is falsy. - */ -function toEpochSeconds(isoString: string | undefined | null): number | null { - if (!isoString) { - return null; - } - const ms = new Date(isoString).getTime(); - if (isNaN(ms)) { - return null; - } - return Math.floor(ms / 1000); -} - -/** - * Flush all buffered connection log records to the database. - * - * Swaps out the buffer before writing so that any records added during the - * flush are captured in the new buffer rather than being lost. Entries that - * fail to write are re-queued back into the buffer so they will be retried - * on the next flush. - * - * This function is exported so that the application's graceful-shutdown - * cleanup handler can call it before the process exits. 
- */ export async function flushConnectionLogToDb(): Promise<void> { - if (buffer.length === 0) { - return; - } - - // Atomically swap out the buffer so new data keeps flowing in - const snapshot = buffer; - buffer = []; - - logger.debug( - `Flushing ${snapshot.length} connection log record(s) to the database` - ); - - // Insert in batches to avoid overly large SQL statements - for (let i = 0; i < snapshot.length; i += INSERT_BATCH_SIZE) { - const batch = snapshot.slice(i, i + INSERT_BATCH_SIZE); - - try { - await withDeadlockRetry(async () => { - await db.insert(connectionAuditLog).values(batch); - }, `flush connection log batch (${batch.length} records)`); - } catch (error) { - logger.error( - `Failed to flush connection log batch of ${batch.length} records:`, - error - ); - - // Re-queue the failed batch so it is retried on the next flush - buffer = [...batch, ...buffer]; - - // Cap buffer to prevent unbounded growth if DB is unreachable - if (buffer.length > MAX_BUFFERED_RECORDS * 5) { - const dropped = buffer.length - MAX_BUFFERED_RECORDS * 5; - buffer = buffer.slice(0, MAX_BUFFERED_RECORDS * 5); - logger.warn( - `Connection log buffer overflow, dropped ${dropped} oldest records` - ); - } - - // Stop trying further batches from this snapshot — they'll be - // picked up by the next flush via the re-queued records above - const remaining = snapshot.slice(i + INSERT_BATCH_SIZE); - if (remaining.length > 0) { - buffer = [...remaining, ...buffer]; - } - break; - } - } + return; } -const flushTimer = setInterval(async () => { - try { - await flushConnectionLogToDb(); - } catch (error) { - logger.error( - "Unexpected error during periodic connection log flush:", - error - ); - } -}, FLUSH_INTERVAL_MS); - -// Calling unref() means this timer will not keep the Node.js event loop alive -// on its own — the process can still exit normally when there is no other work -// left. The graceful-shutdown path will call flushConnectionLogToDb() explicitly -// before process.exit(), so no data is lost. 
-flushTimer.unref(); +export async function cleanUpOldLogs(orgId: string, retentionDays: number) { + return; +} export const handleConnectionLogMessage: MessageHandler = async (context) => { - const { message, client } = context; - const newt = client as Newt; - - if (!newt) { - logger.warn("Connection log received but no newt client in context"); - return; - } - - if (!newt.siteId) { - logger.warn("Connection log received but newt has no siteId"); - return; - } - - if (!message.data?.compressed) { - logger.warn("Connection log message missing compressed data"); - return; - } - - // Look up the org for this site - const [site] = await db - .select({ orgId: sites.orgId }) - .from(sites) - .where(eq(sites.siteId, newt.siteId)); - - if (!site) { - logger.warn( - `Connection log received but site ${newt.siteId} not found in database` - ); - return; - } - - const orgId = site.orgId; - - let sessions: ConnectionSessionData[]; - try { - sessions = await decompressConnectionLog(message.data.compressed); - } catch (error) { - logger.error("Failed to decompress connection log data:", error); - return; - } - - if (sessions.length === 0) { - return; - } - - // Convert to DB records and add to the buffer - for (const session of sessions) { - // Validate required fields - if ( - !session.sessionId || - !session.resourceId || - !session.sourceAddr || - !session.destAddr || - !session.protocol - ) { - logger.debug( - `Skipping connection log session with missing required fields: ${JSON.stringify(session)}` - ); - continue; - } - - const startedAt = toEpochSeconds(session.startedAt); - if (startedAt === null) { - logger.debug( - `Skipping connection log session with invalid startedAt: ${session.startedAt}` - ); - continue; - } - - buffer.push({ - sessionId: session.sessionId, - siteResourceId: session.resourceId, - orgId, - siteId: newt.siteId, - sourceAddr: session.sourceAddr, - destAddr: session.destAddr, - protocol: session.protocol, - startedAt, - endedAt: toEpochSeconds(session.endedAt), - bytesTx: session.bytesTx ?? null, - bytesRx: session.bytesRx ?? 
null - }); - } - - logger.debug( - `Buffered ${sessions.length} connection log session(s) from newt ${newt.newtId} (site ${newt.siteId})` - ); - - // If the buffer has grown large enough, trigger an immediate flush - if (buffer.length >= MAX_BUFFERED_RECORDS) { - // Fire and forget — errors are handled inside flushConnectionLogToDb - flushConnectionLogToDb().catch((error) => { - logger.error( - "Unexpected error during size-triggered connection log flush:", - error - ); - }); - } + return; }; From 6471571bc66798a97cfd3a87fe54e509f8a81822 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 20:05:54 -0700 Subject: [PATCH 03/24] Source client info into schema --- server/db/pg/schema/privateSchema.ts | 6 ++ server/db/sqlite/schema/privateSchema.ts | 6 ++ .../newt/handleConnectionLogMessage.ts | 64 ++++++++++++++++++- 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index 8fed6462a..8d4e663df 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -319,6 +319,12 @@ export const connectionAuditLog = pgTable( siteId: integer("siteId").references(() => sites.siteId, { onDelete: "cascade" }), + clientId: integer("clientId").references(() => clients.clientId, { + onDelete: "cascade" + }), + userId: text("userId").references(() => users.userId, { + onDelete: "cascade" + }), sourceAddr: text("sourceAddr").notNull(), destAddr: text("destAddr").notNull(), protocol: text("protocol").notNull(), diff --git a/server/db/sqlite/schema/privateSchema.ts b/server/db/sqlite/schema/privateSchema.ts index ecc386ea6..f58b5cd18 100644 --- a/server/db/sqlite/schema/privateSchema.ts +++ b/server/db/sqlite/schema/privateSchema.ts @@ -309,6 +309,12 @@ export const connectionAuditLog = sqliteTable( siteId: integer("siteId").references(() => sites.siteId, { onDelete: "cascade" }), + clientId: integer("clientId").references(() => clients.clientId, { + onDelete: "cascade" + }), + userId: text("userId").references(() => users.userId, { + onDelete: "cascade" + }), sourceAddr: text("sourceAddr").notNull(), destAddr: text("destAddr").notNull(), protocol: text("protocol").notNull(), diff --git a/server/private/routers/newt/handleConnectionLogMessage.ts b/server/private/routers/newt/handleConnectionLogMessage.ts index 7549ab37f..164c14488 100644 --- a/server/private/routers/newt/handleConnectionLogMessage.ts +++ b/server/private/routers/newt/handleConnectionLogMessage.ts @@ -1,7 +1,7 @@ import { db, logsDb } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { connectionAuditLog, sites, Newt } from "@server/db"; -import { and, eq, lt } from "drizzle-orm"; +import { connectionAuditLog, sites, Newt, clients, orgs } from "@server/db"; +import { and, eq, lt, inArray } from "drizzle-orm"; import logger from "@server/logger"; import { inflate } from "zlib"; import { promisify } from "util"; @@ -39,6 +39,8 @@ interface ConnectionLogRecord { siteResourceId: number; orgId: string; siteId: number; + clientId: number | null; + userId: string | null; sourceAddr: string; destAddr: string; protocol: string; @@ -243,8 +245,9 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { // Look up the org for this site const [site] = await db - .select({ orgId: sites.orgId }) + .select({ orgId: sites.orgId, orgSubnet: orgs.subnet }) .from(sites) + .innerJoin(orgs, eq(sites.orgId, orgs.orgId)) .where(eq(sites.siteId, newt.siteId)); if (!site) { @@ -256,6 +259,12 @@ 
export const handleConnectionLogMessage: MessageHandler = async (context) => { const orgId = site.orgId; + // Extract the CIDR suffix (e.g. "/16") from the org subnet so we can + // reconstruct the exact subnet string stored on each client record. + const cidrSuffix = site.orgSubnet?.includes("/") ? site.orgSubnet.substring(site.orgSubnet.indexOf("/")) : null; + let sessions: ConnectionSessionData[]; try { sessions = await decompressConnectionLog(message.data.compressed); @@ -268,6 +277,48 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { return; } + // Build a map from sourceAddr → { clientId, userId } by querying clients + // whose subnet field matches exactly. Client subnets are stored with the + // org's CIDR suffix (e.g. "100.90.128.5/16"), so we reconstruct that from + // each unique sourceAddr + the org's CIDR suffix and do a targeted IN query. + const ipToClient = new Map<string, { clientId: number; userId: string | null }>(); + + if (cidrSuffix) { + // Collect unique source addresses so we only query for what we need + const uniqueSourceAddrs = new Set<string>(); + for (const session of sessions) { + if (session.sourceAddr) { + uniqueSourceAddrs.add(session.sourceAddr); + } + } + + if (uniqueSourceAddrs.size > 0) { + // Construct the exact subnet strings as stored in the DB + const subnetQueries = Array.from(uniqueSourceAddrs).map( + (addr) => `${addr}${cidrSuffix}` + ); + + const matchedClients = await db + .select({ + clientId: clients.clientId, + userId: clients.userId, + subnet: clients.subnet + }) + .from(clients) + .where( + and( + eq(clients.orgId, orgId), + inArray(clients.subnet, subnetQueries) + ) + ); + + for (const c of matchedClients) { + const ip = c.subnet.split("/")[0]; + ipToClient.set(ip, { clientId: c.clientId, userId: c.userId }); + } + } + } + // Convert to DB records and add to the buffer for (const session of sessions) { // Validate required fields @@ -292,11 +343,18 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { continue; } + // Match the source address to a client. The sourceAddr is the + // client's IP on the WireGuard network, which corresponds to the IP + // portion of the client's subnet CIDR (e.g. "100.90.128.5/16"). + const clientInfo = ipToClient.get(session.sourceAddr) ?? null; + buffer.push({ sessionId: session.sessionId, siteResourceId: session.resourceId, orgId, siteId: newt.siteId, + clientId: clientInfo?.clientId ?? null, + userId: clientInfo?.userId ?? 
null, sourceAddr: session.sourceAddr, destAddr: session.destAddr, protocol: session.protocol, From fe40ea58c14a5eae05efdbdd64f6460d12566b4d Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 20:18:03 -0700 Subject: [PATCH 04/24] Add ui for connection logs --- messages/en-US.json | 6 + server/lib/billing/tierMatrix.ts | 2 + .../auditLogs/exportConnectionAuditLog.ts | 99 +++ server/private/routers/auditLogs/index.ts | 2 + .../auditLogs/queryConnectionAuditLog.ts | 378 +++++++++++ server/private/routers/external.ts | 19 + server/private/routers/integration.ts | 19 + server/routers/auditLogs/types.ts | 31 + .../[orgId]/settings/logs/connection/page.tsx | 630 ++++++++++++++++++ src/app/navigation.tsx | 6 + 10 files changed, 1192 insertions(+) create mode 100644 server/private/routers/auditLogs/exportConnectionAuditLog.ts create mode 100644 server/private/routers/auditLogs/queryConnectionAuditLog.ts create mode 100644 src/app/[orgId]/settings/logs/connection/page.tsx diff --git a/messages/en-US.json b/messages/en-US.json index 895ee1332..3be427ee0 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -2345,6 +2345,12 @@ "logRetentionEndOfFollowingYear": "End of following year", "actionLogsDescription": "View a history of actions performed in this organization", "accessLogsDescription": "View access auth requests for resources in this organization", + "connectionLogs": "Connection Logs", + "connectionLogsDescription": "View connection logs for tunnels in this organization", + "sidebarLogsConnection": "Connection Logs", + "sourceAddress": "Source Address", + "destinationAddress": "Destination Address", + "duration": "Duration", "licenseRequiredToUse": "An Enterprise Edition license or Pangolin Cloud is required to use this feature. Book a demo or POC trial.", "ossEnterpriseEditionRequired": "The Enterprise Edition is required to use this feature. This feature is also available in Pangolin Cloud. Book a demo or POC trial.", "certResolver": "Certificate Resolver", diff --git a/server/lib/billing/tierMatrix.ts b/server/lib/billing/tierMatrix.ts index c08bcea71..f8a0cd2f5 100644 --- a/server/lib/billing/tierMatrix.ts +++ b/server/lib/billing/tierMatrix.ts @@ -8,6 +8,7 @@ export enum TierFeature { LogExport = "logExport", AccessLogs = "accessLogs", // set the retention period to none on downgrade ActionLogs = "actionLogs", // set the retention period to none on downgrade + ConnectionLogs = "connectionLogs", RotateCredentials = "rotateCredentials", MaintencePage = "maintencePage", // handle downgrade DevicePosture = "devicePosture", @@ -26,6 +27,7 @@ export const tierMatrix: Record<TierFeature, string[]> = { [TierFeature.LogExport]: ["tier3", "enterprise"], [TierFeature.AccessLogs]: ["tier2", "tier3", "enterprise"], [TierFeature.ActionLogs]: ["tier2", "tier3", "enterprise"], + [TierFeature.ConnectionLogs]: ["tier2", "tier3", "enterprise"], [TierFeature.RotateCredentials]: ["tier1", "tier2", "tier3", "enterprise"], [TierFeature.MaintencePage]: ["tier1", "tier2", "tier3", "enterprise"], [TierFeature.DevicePosture]: ["tier2", "tier3", "enterprise"], diff --git a/server/private/routers/auditLogs/exportConnectionAuditLog.ts b/server/private/routers/auditLogs/exportConnectionAuditLog.ts new file mode 100644 index 000000000..9349528ad --- /dev/null +++ b/server/private/routers/auditLogs/exportConnectionAuditLog.ts @@ -0,0 +1,99 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. 
You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { registry } from "@server/openApi"; +import { NextFunction } from "express"; +import { Request, Response } from "express"; +import { OpenAPITags } from "@server/openApi"; +import createHttpError from "http-errors"; +import HttpCode from "@server/types/HttpCode"; +import { fromError } from "zod-validation-error"; +import logger from "@server/logger"; +import { + queryConnectionAuditLogsParams, + queryConnectionAuditLogsQuery, + queryConnection, + countConnectionQuery +} from "./queryConnectionAuditLog"; +import { generateCSV } from "@server/routers/auditLogs/generateCSV"; +import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs"; + +registry.registerPath({ + method: "get", + path: "/org/{orgId}/logs/connection/export", + description: "Export the connection audit log for an organization as CSV", + tags: [OpenAPITags.Logs], + request: { + query: queryConnectionAuditLogsQuery, + params: queryConnectionAuditLogsParams + }, + responses: {} +}); + +export async function exportConnectionAuditLogs( + req: Request, + res: Response, + next: NextFunction +): Promise<any> { + try { + const parsedQuery = queryConnectionAuditLogsQuery.safeParse(req.query); + if (!parsedQuery.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedQuery.error) + ) + ); + } + + const parsedParams = queryConnectionAuditLogsParams.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error) + ) + ); + } + + const data = { ...parsedQuery.data, ...parsedParams.data }; + const [{ count }] = await countConnectionQuery(data); + if (count > MAX_EXPORT_LIMIT) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + `Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.` + ) + ); + } + + const baseQuery = queryConnection(data); + + const log = await baseQuery.limit(data.limit).offset(data.offset); + + const csvData = generateCSV(log); + + res.setHeader("Content-Type", "text/csv"); + res.setHeader( + "Content-Disposition", + `attachment; filename="connection-audit-logs-${data.orgId}-${Date.now()}.csv"` + ); + + return res.send(csvData); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} \ No newline at end of file diff --git a/server/private/routers/auditLogs/index.ts b/server/private/routers/auditLogs/index.ts index e1849a617..122455fea 100644 --- a/server/private/routers/auditLogs/index.ts +++ b/server/private/routers/auditLogs/index.ts @@ -15,3 +15,5 @@ export * from "./queryActionAuditLog"; export * from "./exportActionAuditLog"; export * from "./queryAccessAuditLog"; export * from "./exportAccessAuditLog"; +export * from "./queryConnectionAuditLog"; +export * from "./exportConnectionAuditLog"; diff --git a/server/private/routers/auditLogs/queryConnectionAuditLog.ts b/server/private/routers/auditLogs/queryConnectionAuditLog.ts new file mode 100644 index 000000000..f321444cd --- /dev/null +++ b/server/private/routers/auditLogs/queryConnectionAuditLog.ts @@ -0,0 +1,378 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. 
+ * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { + connectionAuditLog, + logsDb, + siteResources, + sites, + clients, + primaryDb +} from "@server/db"; +import { registry } from "@server/openApi"; +import { NextFunction } from "express"; +import { Request, Response } from "express"; +import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm"; +import { OpenAPITags } from "@server/openApi"; +import { z } from "zod"; +import createHttpError from "http-errors"; +import HttpCode from "@server/types/HttpCode"; +import { fromError } from "zod-validation-error"; +import { QueryConnectionAuditLogResponse } from "@server/routers/auditLogs/types"; +import response from "@server/lib/response"; +import logger from "@server/logger"; +import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo"; + +export const queryConnectionAuditLogsQuery = z.object({ + // iso string just validate its a parseable date + timeStart: z + .string() + .refine((val) => !isNaN(Date.parse(val)), { + error: "timeStart must be a valid ISO date string" + }) + .transform((val) => Math.floor(new Date(val).getTime() / 1000)) + .prefault(() => getSevenDaysAgo().toISOString()) + .openapi({ + type: "string", + format: "date-time", + description: + "Start time as ISO date string (defaults to 7 days ago)" + }), + timeEnd: z + .string() + .refine((val) => !isNaN(Date.parse(val)), { + error: "timeEnd must be a valid ISO date string" + }) + .transform((val) => Math.floor(new Date(val).getTime() / 1000)) + .optional() + .prefault(() => new Date().toISOString()) + .openapi({ + type: "string", + format: "date-time", + description: + "End time as ISO date string (defaults to current time)" + }), + protocol: z.string().optional(), + sourceAddr: z.string().optional(), + destAddr: z.string().optional(), + clientId: z + .string() + .optional() + .transform(Number) + .pipe(z.int().positive()) + .optional(), + siteId: z + .string() + .optional() + .transform(Number) + .pipe(z.int().positive()) + .optional(), + siteResourceId: z + .string() + .optional() + .transform(Number) + .pipe(z.int().positive()) + .optional(), + userId: z.string().optional(), + limit: z + .string() + .optional() + .default("1000") + .transform(Number) + .pipe(z.int().positive()), + offset: z + .string() + .optional() + .default("0") + .transform(Number) + .pipe(z.int().nonnegative()) +}); + +export const queryConnectionAuditLogsParams = z.object({ + orgId: z.string() +}); + +export const queryConnectionAuditLogsCombined = + queryConnectionAuditLogsQuery.merge(queryConnectionAuditLogsParams); +type Q = z.infer<typeof queryConnectionAuditLogsCombined>; + +function getWhere(data: Q) { + return and( + gt(connectionAuditLog.startedAt, data.timeStart), + lt(connectionAuditLog.startedAt, data.timeEnd), + eq(connectionAuditLog.orgId, data.orgId), + data.protocol + ? eq(connectionAuditLog.protocol, data.protocol) + : undefined, + data.sourceAddr + ? eq(connectionAuditLog.sourceAddr, data.sourceAddr) + : undefined, + data.destAddr + ? eq(connectionAuditLog.destAddr, data.destAddr) + : undefined, + data.clientId + ? eq(connectionAuditLog.clientId, data.clientId) + : undefined, + data.siteId + ? eq(connectionAuditLog.siteId, data.siteId) + : undefined, + data.siteResourceId + ? eq(connectionAuditLog.siteResourceId, data.siteResourceId) + : undefined, + data.userId + ? 
eq(connectionAuditLog.userId, data.userId) + : undefined + ); +} + +export function queryConnection(data: Q) { + return logsDb + .select({ + sessionId: connectionAuditLog.sessionId, + siteResourceId: connectionAuditLog.siteResourceId, + orgId: connectionAuditLog.orgId, + siteId: connectionAuditLog.siteId, + clientId: connectionAuditLog.clientId, + userId: connectionAuditLog.userId, + sourceAddr: connectionAuditLog.sourceAddr, + destAddr: connectionAuditLog.destAddr, + protocol: connectionAuditLog.protocol, + startedAt: connectionAuditLog.startedAt, + endedAt: connectionAuditLog.endedAt, + bytesTx: connectionAuditLog.bytesTx, + bytesRx: connectionAuditLog.bytesRx + }) + .from(connectionAuditLog) + .where(getWhere(data)) + .orderBy( + desc(connectionAuditLog.startedAt), + desc(connectionAuditLog.id) + ); +} + +export function countConnectionQuery(data: Q) { + const countQuery = logsDb + .select({ count: count() }) + .from(connectionAuditLog) + .where(getWhere(data)); + return countQuery; +} + +async function enrichWithDetails( + logs: Awaited<ReturnType<typeof queryConnection>> +) { + // Collect unique IDs from logs + const siteResourceIds = [ + ...new Set( + logs + .map((log) => log.siteResourceId) + .filter((id): id is number => id !== null && id !== undefined) + ) + ]; + const siteIds = [ + ...new Set( + logs + .map((log) => log.siteId) + .filter((id): id is number => id !== null && id !== undefined) + ) + ]; + const clientIds = [ + ...new Set( + logs + .map((log) => log.clientId) + .filter((id): id is number => id !== null && id !== undefined) + ) + ]; + + // Fetch resource details from main database + const resourceMap = new Map< + number, + { name: string; niceId: string } + >(); + if (siteResourceIds.length > 0) { + const resourceDetails = await primaryDb + .select({ + siteResourceId: siteResources.siteResourceId, + name: siteResources.name, + niceId: siteResources.niceId + }) + .from(siteResources) + .where(inArray(siteResources.siteResourceId, siteResourceIds)); + + for (const r of resourceDetails) { + resourceMap.set(r.siteResourceId, { + name: r.name, + niceId: r.niceId + }); + } + } + + // Fetch site details from main database + const siteMap = new Map<number, { name: string; niceId: string }>(); + if (siteIds.length > 0) { + const siteDetails = await primaryDb + .select({ + siteId: sites.siteId, + name: sites.name, + niceId: sites.niceId + }) + .from(sites) + .where(inArray(sites.siteId, siteIds)); + + for (const s of siteDetails) { + siteMap.set(s.siteId, { name: s.name, niceId: s.niceId }); + } + } + + // Fetch client details from main database + const clientMap = new Map<number, { name: string }>(); + if (clientIds.length > 0) { + const clientDetails = await primaryDb + .select({ + clientId: clients.clientId, + name: clients.name + }) + .from(clients) + .where(inArray(clients.clientId, clientIds)); + + for (const c of clientDetails) { + clientMap.set(c.clientId, { name: c.name }); + } + } + + // Enrich logs with details + return logs.map((log) => ({ + ...log, + resourceName: log.siteResourceId + ? resourceMap.get(log.siteResourceId)?.name ?? null + : null, + resourceNiceId: log.siteResourceId + ? resourceMap.get(log.siteResourceId)?.niceId ?? null + : null, + siteName: log.siteId + ? siteMap.get(log.siteId)?.name ?? null + : null, + siteNiceId: log.siteId + ? siteMap.get(log.siteId)?.niceId ?? null + : null, + clientName: log.clientId + ? clientMap.get(log.clientId)?.name ?? 
null + : null + })); +} + +async function queryUniqueFilterAttributes( + timeStart: number, + timeEnd: number, + orgId: string +) { + const baseConditions = and( + gt(connectionAuditLog.startedAt, timeStart), + lt(connectionAuditLog.startedAt, timeEnd), + eq(connectionAuditLog.orgId, orgId) + ); + + // Get unique protocols + const uniqueProtocols = await logsDb + .selectDistinct({ + protocol: connectionAuditLog.protocol + }) + .from(connectionAuditLog) + .where(baseConditions); + + return { + protocols: uniqueProtocols + .map((row) => row.protocol) + .filter((protocol): protocol is string => protocol !== null) + }; +} + +registry.registerPath({ + method: "get", + path: "/org/{orgId}/logs/connection", + description: "Query the connection audit log for an organization", + tags: [OpenAPITags.Logs], + request: { + query: queryConnectionAuditLogsQuery, + params: queryConnectionAuditLogsParams + }, + responses: {} +}); + +export async function queryConnectionAuditLogs( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedQuery = queryConnectionAuditLogsQuery.safeParse(req.query); + if (!parsedQuery.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedQuery.error) + ) + ); + } + const parsedParams = queryConnectionAuditLogsParams.safeParse( + req.params + ); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error) + ) + ); + } + + const data = { ...parsedQuery.data, ...parsedParams.data }; + + const baseQuery = queryConnection(data); + + const logsRaw = await baseQuery.limit(data.limit).offset(data.offset); + + // Enrich with resource, site, and client details + const log = await enrichWithDetails(logsRaw); + + const totalCountResult = await countConnectionQuery(data); + const totalCount = totalCountResult[0].count; + + const filterAttributes = await queryUniqueFilterAttributes( + data.timeStart, + data.timeEnd, + data.orgId + ); + + return response(res, { + data: { + log: log, + pagination: { + total: totalCount, + limit: data.limit, + offset: data.offset + }, + filterAttributes + }, + success: true, + error: false, + message: "Connection audit logs retrieved successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} \ No newline at end of file diff --git a/server/private/routers/external.ts b/server/private/routers/external.ts index df8ea8cbb..f06ad4517 100644 --- a/server/private/routers/external.ts +++ b/server/private/routers/external.ts @@ -478,6 +478,25 @@ authenticated.get( logs.exportAccessAuditLogs ); +authenticated.get( + "/org/:orgId/logs/connection", + verifyValidLicense, + verifyValidSubscription(tierMatrix.connectionLogs), + verifyOrgAccess, + verifyUserHasAction(ActionsEnum.exportLogs), + logs.queryConnectionAuditLogs +); + +authenticated.get( + "/org/:orgId/logs/connection/export", + verifyValidLicense, + verifyValidSubscription(tierMatrix.logExport), + verifyOrgAccess, + verifyUserHasAction(ActionsEnum.exportLogs), + logActionAudit(ActionsEnum.exportLogs), + logs.exportConnectionAuditLogs +); + authenticated.post( "/re-key/:clientId/regenerate-client-secret", verifyClientAccess, // this is first to set the org id diff --git a/server/private/routers/integration.ts b/server/private/routers/integration.ts index 97b1adade..c17835025 100644 --- a/server/private/routers/integration.ts +++ b/server/private/routers/integration.ts 
@@ -91,6 +91,25 @@ authenticated.get( logs.exportAccessAuditLogs ); +authenticated.get( + "/org/:orgId/logs/connection", + verifyValidLicense, + verifyValidSubscription(tierMatrix.connectionLogs), + verifyApiKeyOrgAccess, + verifyApiKeyHasAction(ActionsEnum.exportLogs), + logs.queryConnectionAuditLogs +); + +authenticated.get( + "/org/:orgId/logs/connection/export", + verifyValidLicense, + verifyValidSubscription(tierMatrix.logExport), + verifyApiKeyOrgAccess, + verifyApiKeyHasAction(ActionsEnum.exportLogs), + logActionAudit(ActionsEnum.exportLogs), + logs.exportConnectionAuditLogs +); + authenticated.put( "/org/:orgId/idp/oidc", verifyValidLicense, diff --git a/server/routers/auditLogs/types.ts b/server/routers/auditLogs/types.ts index 474aa9261..20e11e17b 100644 --- a/server/routers/auditLogs/types.ts +++ b/server/routers/auditLogs/types.ts @@ -91,3 +91,34 @@ export type QueryAccessAuditLogResponse = { locations: string[]; }; }; + +export type QueryConnectionAuditLogResponse = { + log: { + sessionId: string; + siteResourceId: number | null; + orgId: string | null; + siteId: number | null; + clientId: number | null; + userId: string | null; + sourceAddr: string; + destAddr: string; + protocol: string; + startedAt: number; + endedAt: number | null; + bytesTx: number | null; + bytesRx: number | null; + resourceName: string | null; + resourceNiceId: string | null; + siteName: string | null; + siteNiceId: string | null; + clientName: string | null; + }[]; + pagination: { + total: number; + limit: number; + offset: number; + }; + filterAttributes: { + protocols: string[]; + }; +}; diff --git a/src/app/[orgId]/settings/logs/connection/page.tsx b/src/app/[orgId]/settings/logs/connection/page.tsx new file mode 100644 index 000000000..737b1efd7 --- /dev/null +++ b/src/app/[orgId]/settings/logs/connection/page.tsx @@ -0,0 +1,630 @@ +"use client"; +import { ColumnFilter } from "@app/components/ColumnFilter"; +import { DateTimeValue } from "@app/components/DateTimePicker"; +import { LogDataTable } from "@app/components/LogDataTable"; +import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert"; +import SettingsSectionTitle from "@app/components/SettingsSectionTitle"; +import { useEnvContext } from "@app/hooks/useEnvContext"; +import { usePaidStatus } from "@app/hooks/usePaidStatus"; +import { useStoredPageSize } from "@app/hooks/useStoredPageSize"; +import { toast } from "@app/hooks/useToast"; +import { createApiClient } from "@app/lib/api"; +import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo"; +import { build } from "@server/build"; +import { tierMatrix } from "@server/lib/billing/tierMatrix"; +import { ColumnDef } from "@tanstack/react-table"; +import axios from "axios"; +import { Cable, Monitor, Server } from "lucide-react"; +import { useTranslations } from "next-intl"; +import { useParams, useRouter, useSearchParams } from "next/navigation"; +import { useEffect, useState, useTransition } from "react"; + +function formatBytes(bytes: number | null): string { + if (bytes === null || bytes === undefined) return "—"; + if (bytes === 0) return "0 B"; + const units = ["B", "KB", "MB", "GB", "TB"]; + const i = Math.floor(Math.log(bytes) / Math.log(1024)); + const value = bytes / Math.pow(1024, i); + return `${value.toFixed(i === 0 ? 
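// (Worked example, assuming the helper above: for bytes = 1536,
// i = floor(log(1536) / log(1024)) = 1 and value = 1536 / 1024 = 1.5, so
// formatBytes(1536) returns "1.5 KB"; byte counts keep zero decimals, e.g.
// formatBytes(512) returns "512 B" and formatBytes(0) short-circuits to "0 B".)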
0 : 1)} ${units[i]}`; +} + +function formatDuration(startedAt: number, endedAt: number | null): string { + if (endedAt === null || endedAt === undefined) return "Active"; + const durationSec = endedAt - startedAt; + if (durationSec < 0) return "—"; + if (durationSec < 60) return `${durationSec}s`; + if (durationSec < 3600) { + const m = Math.floor(durationSec / 60); + const s = durationSec % 60; + return `${m}m ${s}s`; + } + const h = Math.floor(durationSec / 3600); + const m = Math.floor((durationSec % 3600) / 60); + return `${h}h ${m}m`; +} + +export default function ConnectionLogsPage() { + const router = useRouter(); + const api = createApiClient(useEnvContext()); + const t = useTranslations(); + const { orgId } = useParams(); + const searchParams = useSearchParams(); + + const { isPaidUser } = usePaidStatus(); + + const [rows, setRows] = useState([]); + const [isRefreshing, setIsRefreshing] = useState(false); + const [isExporting, startTransition] = useTransition(); + const [filterAttributes, setFilterAttributes] = useState<{ + protocols: string[]; + }>({ + protocols: [] + }); + + // Filter states - unified object for all filters + const [filters, setFilters] = useState<{ + protocol?: string; + }>({ + protocol: searchParams.get("protocol") || undefined + }); + + // Pagination state + const [totalCount, setTotalCount] = useState(0); + const [currentPage, setCurrentPage] = useState(0); + const [isLoading, setIsLoading] = useState(false); + + // Initialize page size from storage or default + const [pageSize, setPageSize] = useStoredPageSize( + "connection-audit-logs", + 20 + ); + + // Set default date range to last 7 days + const getDefaultDateRange = () => { + // if the time is in the url params, use that instead + const startParam = searchParams.get("start"); + const endParam = searchParams.get("end"); + if (startParam && endParam) { + return { + startDate: { + date: new Date(startParam) + }, + endDate: { + date: new Date(endParam) + } + }; + } + + const now = new Date(); + const lastWeek = getSevenDaysAgo(); + + return { + startDate: { + date: lastWeek + }, + endDate: { + date: now + } + }; + }; + + const [dateRange, setDateRange] = useState<{ + startDate: DateTimeValue; + endDate: DateTimeValue; + }>(getDefaultDateRange()); + + // Trigger search with default values on component mount + useEffect(() => { + if (build === "oss") { + return; + } + const defaultRange = getDefaultDateRange(); + queryDateTime( + defaultRange.startDate, + defaultRange.endDate, + 0, + pageSize + ); + }, [orgId]); // Re-run if orgId changes + + const handleDateRangeChange = ( + startDate: DateTimeValue, + endDate: DateTimeValue + ) => { + setDateRange({ startDate, endDate }); + setCurrentPage(0); // Reset to first page when filtering + // put the search params in the url for the time + updateUrlParamsForAllFilters({ + start: startDate.date?.toISOString() || "", + end: endDate.date?.toISOString() || "" + }); + + queryDateTime(startDate, endDate, 0, pageSize); + }; + + // Handle page changes + const handlePageChange = (newPage: number) => { + setCurrentPage(newPage); + queryDateTime( + dateRange.startDate, + dateRange.endDate, + newPage, + pageSize + ); + }; + + // Handle page size changes + const handlePageSizeChange = (newPageSize: number) => { + setPageSize(newPageSize); + setCurrentPage(0); // Reset to first page when changing page size + queryDateTime(dateRange.startDate, dateRange.endDate, 0, newPageSize); + }; + + // Handle filter changes generically + const handleFilterChange = ( + filterType: keyof 
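// (Usage sketch, not part of this patch: a protocol pick in the column header
// calls handleFilterChange("protocol", "tcp"), which resets to page 0, mirrors
// the value into the URL via updateUrlParamsForAllFilters, and passes the
// freshly built newFilters object straight into queryDateTime so the fetch
// does not race the asynchronous setFilters state update.)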
typeof filters, + value: string | undefined + ) => { + // Create new filters object with updated value + const newFilters = { + ...filters, + [filterType]: value + }; + + setFilters(newFilters); + setCurrentPage(0); // Reset to first page when filtering + + // Update URL params + updateUrlParamsForAllFilters(newFilters); + + // Trigger new query with updated filters (pass directly to avoid async state issues) + queryDateTime( + dateRange.startDate, + dateRange.endDate, + 0, + pageSize, + newFilters + ); + }; + + const updateUrlParamsForAllFilters = ( + newFilters: + | typeof filters + | { + start: string; + end: string; + } + ) => { + const params = new URLSearchParams(searchParams); + Object.entries(newFilters).forEach(([key, value]) => { + if (value) { + params.set(key, value); + } else { + params.delete(key); + } + }); + router.replace(`?${params.toString()}`, { scroll: false }); + }; + + const queryDateTime = async ( + startDate: DateTimeValue, + endDate: DateTimeValue, + page: number = currentPage, + size: number = pageSize, + filtersParam?: { + protocol?: string; + } + ) => { + console.log("Date range changed:", { startDate, endDate, page, size }); + if (!isPaidUser(tierMatrix.connectionLogs)) { + console.log( + "Access denied: subscription inactive or license locked" + ); + return; + } + setIsLoading(true); + + try { + // Use the provided filters or fall back to current state + const activeFilters = filtersParam || filters; + + // Convert the date/time values to API parameters + const params: any = { + limit: size, + offset: page * size, + ...activeFilters + }; + + if (startDate?.date) { + const startDateTime = new Date(startDate.date); + if (startDate.time) { + const [hours, minutes, seconds] = startDate.time + .split(":") + .map(Number); + startDateTime.setHours(hours, minutes, seconds || 0); + } + params.timeStart = startDateTime.toISOString(); + } + + if (endDate?.date) { + const endDateTime = new Date(endDate.date); + if (endDate.time) { + const [hours, minutes, seconds] = endDate.time + .split(":") + .map(Number); + endDateTime.setHours(hours, minutes, seconds || 0); + } else { + // If no time is specified, set to NOW + const now = new Date(); + endDateTime.setHours( + now.getHours(), + now.getMinutes(), + now.getSeconds(), + now.getMilliseconds() + ); + } + params.timeEnd = endDateTime.toISOString(); + } + + const res = await api.get(`/org/${orgId}/logs/connection`, { + params + }); + if (res.status === 200) { + setRows(res.data.data.log || []); + setTotalCount(res.data.data.pagination?.total || 0); + setFilterAttributes(res.data.data.filterAttributes); + console.log("Fetched connection logs:", res.data); + } + } catch (error) { + toast({ + title: t("error"), + description: t("Failed to filter logs"), + variant: "destructive" + }); + } finally { + setIsLoading(false); + } + }; + + const refreshData = async () => { + console.log("Data refreshed"); + setIsRefreshing(true); + try { + // Refresh data with current date range and pagination + await queryDateTime( + dateRange.startDate, + dateRange.endDate, + currentPage, + pageSize + ); + } catch (error) { + toast({ + title: t("error"), + description: t("refreshError"), + variant: "destructive" + }); + } finally { + setIsRefreshing(false); + } + }; + + const exportData = async () => { + try { + // Prepare query params for export + const params: any = { + timeStart: dateRange.startDate?.date + ? new Date(dateRange.startDate.date).toISOString() + : undefined, + timeEnd: dateRange.endDate?.date + ? 
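// (Illustrative values only: with a range of Mar 1–8 and filters.protocol set
// to "tcp", the export call below effectively issues
//     GET /org/{orgId}/logs/connection/export
//         ?timeStart=2026-03-01T00:00:00.000Z
//         &timeEnd=2026-03-08T00:00:00.000Z&protocol=tcp
// and receives the CSV back as a blob for download.)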
new Date(dateRange.endDate.date).toISOString() + : undefined, + ...filters + }; + + const response = await api.get( + `/org/${orgId}/logs/connection/export`, + { + responseType: "blob", + params + } + ); + + // Create a URL for the blob and trigger a download + const url = window.URL.createObjectURL(new Blob([response.data])); + const link = document.createElement("a"); + link.href = url; + const epoch = Math.floor(Date.now() / 1000); + link.setAttribute( + "download", + `connection-audit-logs-${orgId}-${epoch}.csv` + ); + document.body.appendChild(link); + link.click(); + link.parentNode?.removeChild(link); + } catch (error) { + let apiErrorMessage: string | null = null; + if (axios.isAxiosError(error) && error.response) { + const data = error.response.data; + + if (data instanceof Blob && data.type === "application/json") { + // Parse the Blob as JSON + const text = await data.text(); + const errorData = JSON.parse(text); + apiErrorMessage = errorData.message; + } + } + toast({ + title: t("error"), + description: apiErrorMessage ?? t("exportError"), + variant: "destructive" + }); + } + }; + + const columns: ColumnDef[] = [ + { + accessorKey: "startedAt", + header: ({ column }) => { + return t("timestamp"); + }, + cell: ({ row }) => { + return ( +
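// (Note, sketch only: startedAt is stored as epoch seconds, so the cell
// multiplies by 1000 before constructing a Date — e.g. 1772323200 becomes
// new Date(1772323200 * 1000), rendering as "3/1/2026, 12:00:00 AM" in an
// en-US locale with a UTC timezone.)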
+ {new Date( + row.original.startedAt * 1000 + ).toLocaleString()} +
+ ); + } + }, + { + accessorKey: "protocol", + header: ({ column }) => { + return ( +
+ {t("protocol")} + ({ + label: protocol.toUpperCase(), + value: protocol + }) + )} + selectedValue={filters.protocol} + onValueChange={(value) => + handleFilterChange("protocol", value) + } + searchPlaceholder="Search..." + emptyMessage="None found" + /> +
+ ); + }, + cell: ({ row }) => { + return ( + + {row.original.protocol?.toUpperCase()} + + ); + } + }, + { + accessorKey: "resourceName", + header: ({ column }) => { + return t("resource"); + }, + cell: ({ row }) => { + return ( + + {row.original.resourceName ?? "—"} + + ); + } + }, + { + accessorKey: "sourceAddr", + header: ({ column }) => { + return t("sourceAddress"); + }, + cell: ({ row }) => { + return ( + + {row.original.sourceAddr} + + ); + } + }, + { + accessorKey: "destAddr", + header: ({ column }) => { + return t("destinationAddress"); + }, + cell: ({ row }) => { + return ( + + {row.original.destAddr} + + ); + } + }, + { + accessorKey: "duration", + header: ({ column }) => { + return t("duration"); + }, + cell: ({ row }) => { + return ( + + {formatDuration( + row.original.startedAt, + row.original.endedAt + )} + + ); + } + } + ]; + + const renderExpandedRow = (row: any) => { + return ( +
+
+
+
+ + Connection Details +
+
+ Session ID:{" "} + + {row.sessionId ?? "—"} + +
+
+ Protocol:{" "} + {row.protocol?.toUpperCase() ?? "—"} +
+
+ Source:{" "} + + {row.sourceAddr ?? "—"} + +
+
+ Destination:{" "} + + {row.destAddr ?? "—"} + +
+
+ Started At:{" "} + {row.startedAt + ? new Date( + row.startedAt * 1000 + ).toLocaleString() + : "—"} +
+
+ Ended At:{" "} + {row.endedAt + ? new Date( + row.endedAt * 1000 + ).toLocaleString() + : "Active"} +
+
+ Duration:{" "} + {formatDuration(row.startedAt, row.endedAt)} +
+
+
+
+ + Resource & Site +
+
+ Resource:{" "} + {row.resourceName ?? "—"} + {row.resourceNiceId && ( + + ({row.resourceNiceId}) + + )} +
+
+ Site: {row.siteName ?? "—"} + {row.siteNiceId && ( + + ({row.siteNiceId}) + + )} +
+
+ Site ID: {row.siteId ?? "—"} +
+
+ Resource ID:{" "} + {row.siteResourceId ?? "—"} +
+
+
+
+ + Client & Transfer +
+
+ Client: {row.clientName ?? "—"} + {row.clientId && ( + + (ID: {row.clientId}) + + )} +
+
+ User ID: {row.userId ?? "—"} +
+
+ Bytes Sent (TX):{" "} + {formatBytes(row.bytesTx)} +
+
+ Bytes Received (RX):{" "} + {formatBytes(row.bytesRx)} +
+
+ Total Transfer:{" "} + {formatBytes( + (row.bytesTx ?? 0) + (row.bytesRx ?? 0) + )} +
+
+
+
+ ); + }; + + return ( + <> + + + + + startTransition(exportData)} + isExporting={isExporting} + onDateRangeChange={handleDateRangeChange} + dateRange={{ + start: dateRange.startDate, + end: dateRange.endDate + }} + defaultSort={{ + id: "startedAt", + desc: true + }} + // Server-side pagination props + totalCount={totalCount} + currentPage={currentPage} + pageSize={pageSize} + onPageChange={handlePageChange} + onPageSizeChange={handlePageSizeChange} + isLoading={isLoading} + // Row expansion props + expandable={true} + renderExpandedRow={renderExpandedRow} + disabled={ + !isPaidUser(tierMatrix.connectionLogs) || build === "oss" + } + /> + + ); +} \ No newline at end of file diff --git a/src/app/navigation.tsx b/src/app/navigation.tsx index 0066721db..0a09214e3 100644 --- a/src/app/navigation.tsx +++ b/src/app/navigation.tsx @@ -3,6 +3,7 @@ import { Env } from "@app/lib/types/env"; import { build } from "@server/build"; import { Building2, + Cable, ChartLine, Combine, CreditCard, @@ -189,6 +190,11 @@ export const orgNavSections = ( title: "sidebarLogsAction", href: "/{orgId}/settings/logs/action", icon: + }, + { + title: "sidebarLogsConnection", + href: "/{orgId}/settings/logs/connection", + icon: } ] : []) From 2c6e9507b55efdedcf88f77a054475b8dd9daa51 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 21:41:53 -0700 Subject: [PATCH 05/24] Connection log page working --- server/lib/ip.ts | 10 +- .../auditLogs/queryConnectionAuditLog.ts | 156 ++++++++++++++- .../newt/handleConnectionLogMessage.ts | 16 +- server/routers/auditLogs/types.ts | 16 ++ .../[orgId]/settings/logs/connection/page.tsx | 180 ++++++++++++++++-- 5 files changed, 349 insertions(+), 29 deletions(-) diff --git a/server/lib/ip.ts b/server/lib/ip.ts index 3a29b8661..7f829bcef 100644 --- a/server/lib/ip.ts +++ b/server/lib/ip.ts @@ -581,6 +581,7 @@ export type SubnetProxyTargetV2 = { max: number; protocol: "tcp" | "udp"; }[]; + resourceId?: number; }; export function generateSubnetProxyTargetV2( @@ -617,7 +618,8 @@ export function generateSubnetProxyTargetV2( sourcePrefixes: [], destPrefix: destination, portRange, - disableIcmp + disableIcmp, + resourceId: siteResource.siteResourceId, }; } @@ -628,7 +630,8 @@ export function generateSubnetProxyTargetV2( destPrefix: `${siteResource.aliasAddress}/32`, rewriteTo: destination, portRange, - disableIcmp + disableIcmp, + resourceId: siteResource.siteResourceId, }; } } else if (siteResource.mode == "cidr") { @@ -636,7 +639,8 @@ export function generateSubnetProxyTargetV2( sourcePrefixes: [], destPrefix: siteResource.destination, portRange, - disableIcmp + disableIcmp, + resourceId: siteResource.siteResourceId, }; } diff --git a/server/private/routers/auditLogs/queryConnectionAuditLog.ts b/server/private/routers/auditLogs/queryConnectionAuditLog.ts index f321444cd..b638ed488 100644 --- a/server/private/routers/auditLogs/queryConnectionAuditLog.ts +++ b/server/private/routers/auditLogs/queryConnectionAuditLog.ts @@ -17,6 +17,7 @@ import { siteResources, sites, clients, + users, primaryDb } from "@server/db"; import { registry } from "@server/openApi"; @@ -193,6 +194,13 @@ async function enrichWithDetails( .filter((id): id is number => id !== null && id !== undefined) ) ]; + const userIds = [ + ...new Set( + logs + .map((log) => log.userId) + .filter((id): id is string => id !== null && id !== undefined) + ) + ]; // Fetch resource details from main database const resourceMap = new Map< @@ -235,18 +243,46 @@ async function enrichWithDetails( } // Fetch client details from main 
database - const clientMap = new Map(); + const clientMap = new Map< + number, + { name: string; niceId: string; type: string } + >(); if (clientIds.length > 0) { const clientDetails = await primaryDb .select({ clientId: clients.clientId, - name: clients.name + name: clients.name, + niceId: clients.niceId, + type: clients.type }) .from(clients) .where(inArray(clients.clientId, clientIds)); for (const c of clientDetails) { - clientMap.set(c.clientId, { name: c.name }); + clientMap.set(c.clientId, { + name: c.name, + niceId: c.niceId, + type: c.type + }); + } + } + + // Fetch user details from main database + const userMap = new Map< + string, + { email: string | null } + >(); + if (userIds.length > 0) { + const userDetails = await primaryDb + .select({ + userId: users.userId, + email: users.email + }) + .from(users) + .where(inArray(users.userId, userIds)); + + for (const u of userDetails) { + userMap.set(u.userId, { email: u.email }); } } @@ -267,6 +303,15 @@ async function enrichWithDetails( : null, clientName: log.clientId ? clientMap.get(log.clientId)?.name ?? null + : null, + clientNiceId: log.clientId + ? clientMap.get(log.clientId)?.niceId ?? null + : null, + clientType: log.clientId + ? clientMap.get(log.clientId)?.type ?? null + : null, + userEmail: log.userId + ? userMap.get(log.userId)?.email ?? null : null })); } @@ -290,10 +335,111 @@ async function queryUniqueFilterAttributes( .from(connectionAuditLog) .where(baseConditions); + // Get unique destination addresses + const uniqueDestAddrs = await logsDb + .selectDistinct({ + destAddr: connectionAuditLog.destAddr + }) + .from(connectionAuditLog) + .where(baseConditions); + + // Get unique client IDs + const uniqueClients = await logsDb + .selectDistinct({ + clientId: connectionAuditLog.clientId + }) + .from(connectionAuditLog) + .where(baseConditions); + + // Get unique resource IDs + const uniqueResources = await logsDb + .selectDistinct({ + siteResourceId: connectionAuditLog.siteResourceId + }) + .from(connectionAuditLog) + .where(baseConditions); + + // Get unique user IDs + const uniqueUsers = await logsDb + .selectDistinct({ + userId: connectionAuditLog.userId + }) + .from(connectionAuditLog) + .where(baseConditions); + + // Enrich client IDs with names from main database + const clientIds = uniqueClients + .map((row) => row.clientId) + .filter((id): id is number => id !== null); + + let clientsWithNames: Array<{ id: number; name: string }> = []; + if (clientIds.length > 0) { + const clientDetails = await primaryDb + .select({ + clientId: clients.clientId, + name: clients.name + }) + .from(clients) + .where(inArray(clients.clientId, clientIds)); + + clientsWithNames = clientDetails.map((c) => ({ + id: c.clientId, + name: c.name + })); + } + + // Enrich resource IDs with names from main database + const resourceIds = uniqueResources + .map((row) => row.siteResourceId) + .filter((id): id is number => id !== null); + + let resourcesWithNames: Array<{ id: number; name: string | null }> = []; + if (resourceIds.length > 0) { + const resourceDetails = await primaryDb + .select({ + siteResourceId: siteResources.siteResourceId, + name: siteResources.name + }) + .from(siteResources) + .where(inArray(siteResources.siteResourceId, resourceIds)); + + resourcesWithNames = resourceDetails.map((r) => ({ + id: r.siteResourceId, + name: r.name + })); + } + + // Enrich user IDs with emails from main database + const userIdsList = uniqueUsers + .map((row) => row.userId) + .filter((id): id is string => id !== null); + + let usersWithEmails: 
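// (Shape sketch with made-up rows: after this enrichment the endpoint's
// filterAttributes payload looks like
//     { protocols: ["tcp", "udp"],
//       destAddrs: ["10.0.0.5:443"],
//       clients: [{ id: 1, name: "laptop" }],
//       resources: [{ id: 2, name: "internal-web" }],
//       users: [{ id: "u_abc", email: "a@example.com" }] }
// matching QueryConnectionAuditLogResponse["filterAttributes"].)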
Array<{ id: string; email: string | null }> = []; + if (userIdsList.length > 0) { + const userDetails = await primaryDb + .select({ + userId: users.userId, + email: users.email + }) + .from(users) + .where(inArray(users.userId, userIdsList)); + + usersWithEmails = userDetails.map((u) => ({ + id: u.userId, + email: u.email + })); + } + return { protocols: uniqueProtocols .map((row) => row.protocol) - .filter((protocol): protocol is string => protocol !== null) + .filter((protocol): protocol is string => protocol !== null), + destAddrs: uniqueDestAddrs + .map((row) => row.destAddr) + .filter((addr): addr is string => addr !== null), + clients: clientsWithNames, + resources: resourcesWithNames, + users: usersWithEmails }; } @@ -342,7 +488,7 @@ export async function queryConnectionAuditLogs( const logsRaw = await baseQuery.limit(data.limit).offset(data.offset); - // Enrich with resource, site, and client details + // Enrich with resource, site, client, and user details const log = await enrichWithDetails(logsRaw); const totalCountResult = await countConnectionQuery(data); diff --git a/server/private/routers/newt/handleConnectionLogMessage.ts b/server/private/routers/newt/handleConnectionLogMessage.ts index 164c14488..2ac7153b5 100644 --- a/server/private/routers/newt/handleConnectionLogMessage.ts +++ b/server/private/routers/newt/handleConnectionLogMessage.ts @@ -277,6 +277,8 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { return; } + logger.debug(`Sessions: ${JSON.stringify(sessions)}`) + // Build a map from sourceAddr → { clientId, userId } by querying clients // whose subnet field matches exactly. Client subnets are stored with the // org's CIDR suffix (e.g. "100.90.128.5/16"), so we reconstruct that from @@ -295,9 +297,15 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { if (uniqueSourceAddrs.size > 0) { // Construct the exact subnet strings as stored in the DB const subnetQueries = Array.from(uniqueSourceAddrs).map( - (addr) => `${addr}${cidrSuffix}` + (addr) => { + // Strip port if present (e.g. "100.90.128.1:38004" → "100.90.128.1") + const ip = addr.includes(":") ? addr.split(":")[0] : addr; + return `${ip}${cidrSuffix}`; + } ); + logger.debug(`Subnet queries: ${JSON.stringify(subnetQueries)}`); + const matchedClients = await db .select({ clientId: clients.clientId, @@ -314,6 +322,7 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { for (const c of matchedClients) { const ip = c.subnet.split("/")[0]; + logger.debug(`Client ${c.clientId} subnet ${c.subnet} matches ${ip}`); ipToClient.set(ip, { clientId: c.clientId, userId: c.userId }); } } @@ -346,7 +355,10 @@ export const handleConnectionLogMessage: MessageHandler = async (context) => { // Match the source address to a client. The sourceAddr is the // client's IP on the WireGuard network, which corresponds to the IP // portion of the client's subnet CIDR (e.g. "100.90.128.5/24"). - const clientInfo = ipToClient.get(session.sourceAddr) ?? null; + // Strip port if present (e.g. "100.90.128.1:38004" → "100.90.128.1") + const sourceIp = session.sourceAddr.includes(":") ? session.sourceAddr.split(":")[0] : session.sourceAddr; + const clientInfo = ipToClient.get(sourceIp) ?? 
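// (Sketch note: the colon-split above assumes IPv4 "addr:port" strings, e.g.
// "100.90.128.1:38004" -> "100.90.128.1". An IPv6 source such as
// "fd00::5" also contains ":" and would be truncated at the first colon,
// which is worth keeping in mind if client subnets ever move to IPv6.)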
null; + buffer.push({ sessionId: session.sessionId, diff --git a/server/routers/auditLogs/types.ts b/server/routers/auditLogs/types.ts index 20e11e17b..4c278cba5 100644 --- a/server/routers/auditLogs/types.ts +++ b/server/routers/auditLogs/types.ts @@ -112,6 +112,9 @@ export type QueryConnectionAuditLogResponse = { siteName: string | null; siteNiceId: string | null; clientName: string | null; + clientNiceId: string | null; + clientType: string | null; + userEmail: string | null; }[]; pagination: { total: number; @@ -120,5 +123,18 @@ export type QueryConnectionAuditLogResponse = { }; filterAttributes: { protocols: string[]; + destAddrs: string[]; + clients: { + id: number; + name: string; + }[]; + resources: { + id: number; + name: string | null; + }[]; + users: { + id: string; + email: string | null; + }[]; }; }; diff --git a/src/app/[orgId]/settings/logs/connection/page.tsx b/src/app/[orgId]/settings/logs/connection/page.tsx index 737b1efd7..7ac137467 100644 --- a/src/app/[orgId]/settings/logs/connection/page.tsx +++ b/src/app/[orgId]/settings/logs/connection/page.tsx @@ -1,4 +1,5 @@ "use client"; +import { Button } from "@app/components/ui/button"; import { ColumnFilter } from "@app/components/ColumnFilter"; import { DateTimeValue } from "@app/components/DateTimePicker"; import { LogDataTable } from "@app/components/LogDataTable"; @@ -14,7 +15,8 @@ import { build } from "@server/build"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; import { ColumnDef } from "@tanstack/react-table"; import axios from "axios"; -import { Cable, Monitor, Server } from "lucide-react"; +import { ArrowUpRight, Laptop, User } from "lucide-react"; +import Link from "next/link"; import { useTranslations } from "next-intl"; import { useParams, useRouter, useSearchParams } from "next/navigation"; import { useEffect, useState, useTransition } from "react"; @@ -57,15 +59,31 @@ export default function ConnectionLogsPage() { const [isExporting, startTransition] = useTransition(); const [filterAttributes, setFilterAttributes] = useState<{ protocols: string[]; + destAddrs: string[]; + clients: { id: number; name: string }[]; + resources: { id: number; name: string | null }[]; + users: { id: string; email: string | null }[]; }>({ - protocols: [] + protocols: [], + destAddrs: [], + clients: [], + resources: [], + users: [] }); // Filter states - unified object for all filters const [filters, setFilters] = useState<{ protocol?: string; + destAddr?: string; + clientId?: string; + siteResourceId?: string; + userId?: string; }>({ - protocol: searchParams.get("protocol") || undefined + protocol: searchParams.get("protocol") || undefined, + destAddr: searchParams.get("destAddr") || undefined, + clientId: searchParams.get("clientId") || undefined, + siteResourceId: searchParams.get("siteResourceId") || undefined, + userId: searchParams.get("userId") || undefined }); // Pagination state @@ -211,9 +229,7 @@ export default function ConnectionLogsPage() { endDate: DateTimeValue, page: number = currentPage, size: number = pageSize, - filtersParam?: { - protocol?: string; - } + filtersParam?: typeof filters ) => { console.log("Date range changed:", { startDate, endDate, page, size }); if (!isPaidUser(tierMatrix.connectionLogs)) { @@ -411,9 +427,41 @@ export default function ConnectionLogsPage() { { accessorKey: "resourceName", header: ({ column }) => { - return t("resource"); + return ( +
+ {t("resource")} + ({ + value: res.id.toString(), + label: res.name || "Unnamed Resource" + }))} + selectedValue={filters.siteResourceId} + onValueChange={(value) => + handleFilterChange("siteResourceId", value) + } + searchPlaceholder="Search..." + emptyMessage="None found" + /> +
+ ); }, cell: ({ row }) => { + if (row.original.resourceName && row.original.resourceNiceId) { + return ( + + + + ); + } return ( {row.original.resourceName ?? "—"} @@ -421,6 +469,86 @@ export default function ConnectionLogsPage() { ); } }, + { + accessorKey: "clientName", + header: ({ column }) => { + return ( +
+ {t("client")} + ({ + value: c.id.toString(), + label: c.name + }))} + selectedValue={filters.clientId} + onValueChange={(value) => + handleFilterChange("clientId", value) + } + searchPlaceholder="Search..." + emptyMessage="None found" + /> +
+ ); + }, + cell: ({ row }) => { + const clientType = row.original.clientType === "olm" ? "machine" : "user"; + if (row.original.clientName && row.original.clientNiceId) { + return ( + + + + ); + } + return ( + + {row.original.clientName ?? "—"} + + ); + } + }, + { + accessorKey: "userEmail", + header: ({ column }) => { + return ( +
+ {t("user")} + ({ + value: u.id, + label: u.email || u.id + }))} + selectedValue={filters.userId} + onValueChange={(value) => + handleFilterChange("userId", value) + } + searchPlaceholder="Search..." + emptyMessage="None found" + /> +
+ ); + }, + cell: ({ row }) => { + if (row.original.userEmail || row.original.userId) { + return ( + + + {row.original.userEmail ?? row.original.userId} + + ); + } + return ; + } + }, { accessorKey: "sourceAddr", header: ({ column }) => { @@ -437,7 +565,23 @@ export default function ConnectionLogsPage() { { accessorKey: "destAddr", header: ({ column }) => { - return t("destinationAddress"); + return ( +
+ {t("destinationAddress")} + ({ + value: addr, + label: addr + }))} + selectedValue={filters.destAddr} + onValueChange={(value) => + handleFilterChange("destAddr", value) + } + searchPlaceholder="Search..." + emptyMessage="None found" + /> +
+ ); }, cell: ({ row }) => { return ( @@ -470,10 +614,9 @@ export default function ConnectionLogsPage() {
-
- + {/*
Connection Details -
+
*/}
Session ID:{" "} @@ -518,10 +661,9 @@ export default function ConnectionLogsPage() {
-
- + {/*
Resource & Site -
+
*/}
Resource:{" "} {row.resourceName ?? "—"} @@ -548,10 +690,9 @@ export default function ConnectionLogsPage() {
-
- + {/*
Client & Transfer -
+
*/}
Client: {row.clientName ?? "—"} {row.clientId && ( @@ -561,7 +702,8 @@ export default function ConnectionLogsPage() { )}
- User ID: {row.userId ?? "—"} + User:{" "} + {row.userEmail ?? row.userId ?? "—"}
Bytes Sent (TX):{" "} @@ -627,4 +769,4 @@ export default function ConnectionLogsPage() { /> ); -} \ No newline at end of file +} From f9bff5954f190ce697b9d4ed6ed3b3992343e854 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 21:49:22 -0700 Subject: [PATCH 06/24] Add filters and refine table and query --- .../[orgId]/settings/logs/connection/page.tsx | 78 ++++++++----------- 1 file changed, 33 insertions(+), 45 deletions(-) diff --git a/src/app/[orgId]/settings/logs/connection/page.tsx b/src/app/[orgId]/settings/logs/connection/page.tsx index 7ac137467..1d93e0658 100644 --- a/src/app/[orgId]/settings/logs/connection/page.tsx +++ b/src/app/[orgId]/settings/logs/connection/page.tsx @@ -639,6 +639,31 @@ export default function ConnectionLogsPage() { {row.destAddr ?? "—"}
+
+
+ {/*
+ Resource & Site +
*/} + {/*
+ Resource:{" "} + {row.resourceName ?? "—"} + {row.resourceNiceId && ( + + ({row.resourceNiceId}) + + )} +
*/} +
+ Site: {row.siteName ?? "—"} + {row.siteNiceId && ( + + ({row.siteNiceId}) + + )} +
+
+ Site ID: {row.siteId ?? "—"} +
Started At:{" "} {row.startedAt @@ -659,66 +684,29 @@ export default function ConnectionLogsPage() { Duration:{" "} {formatDuration(row.startedAt, row.endedAt)}
-
-
- {/*
- Resource & Site -
*/} -
- Resource:{" "} - {row.resourceName ?? "—"} - {row.resourceNiceId && ( - - ({row.resourceNiceId}) - - )} -
-
- Site: {row.siteName ?? "—"} - {row.siteNiceId && ( - - ({row.siteNiceId}) - - )} -
-
- Site ID: {row.siteId ?? "—"} -
-
+ {/*
Resource ID:{" "} {row.siteResourceId ?? "—"} -
+
*/}
{/*
Client & Transfer
*/} -
- Client: {row.clientName ?? "—"} - {row.clientId && ( - - (ID: {row.clientId}) - - )} -
-
- User:{" "} - {row.userEmail ?? row.userId ?? "—"} -
-
+ {/*
Bytes Sent (TX):{" "} {formatBytes(row.bytesTx)} -
-
+
*/} + {/*
Bytes Received (RX):{" "} {formatBytes(row.bytesRx)} -
-
+
*/} + {/*
Total Transfer:{" "} {formatBytes( (row.bytesTx ?? 0) + (row.bytesRx ?? 0) )} -
+
*/}
From 7b78b914493cb1bf2f5066a05296b787f1dd7517 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 23 Mar 2026 22:00:53 -0700 Subject: [PATCH 07/24] Fix resource link --- src/app/[orgId]/settings/logs/connection/page.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/app/[orgId]/settings/logs/connection/page.tsx b/src/app/[orgId]/settings/logs/connection/page.tsx index 1d93e0658..dff42faac 100644 --- a/src/app/[orgId]/settings/logs/connection/page.tsx +++ b/src/app/[orgId]/settings/logs/connection/page.tsx @@ -449,7 +449,7 @@ export default function ConnectionLogsPage() { if (row.original.resourceName && row.original.resourceNiceId) { return ( + + + + ) : ( + + + + )} + + + + ); +} diff --git a/src/components/SiteProvisioningKeysTable.tsx b/src/components/SiteProvisioningKeysTable.tsx new file mode 100644 index 000000000..3fb3eb872 --- /dev/null +++ b/src/components/SiteProvisioningKeysTable.tsx @@ -0,0 +1,216 @@ +"use client"; + +import { + DataTable, + ExtendedColumnDef +} from "@app/components/ui/data-table"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger +} from "@app/components/ui/dropdown-menu"; +import { Button } from "@app/components/ui/button"; +import { ArrowUpDown, MoreHorizontal } from "lucide-react"; +import { useRouter } from "next/navigation"; +import { useEffect, useState } from "react"; +import CreateSiteProvisioningKeyCredenza from "@app/components/CreateSiteProvisioningKeyCredenza"; +import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog"; +import { toast } from "@app/hooks/useToast"; +import { formatAxiosError } from "@app/lib/api"; +import { createApiClient } from "@app/lib/api"; +import { useEnvContext } from "@app/hooks/useEnvContext"; +import moment from "moment"; +import { useTranslations } from "next-intl"; + +export type SiteProvisioningKeyRow = { + id: string; + key: string; + name: string; + createdAt: string; +}; + +type SiteProvisioningKeysTableProps = { + keys: SiteProvisioningKeyRow[]; + orgId: string; +}; + +export default function SiteProvisioningKeysTable({ + keys, + orgId +}: SiteProvisioningKeysTableProps) { + const router = useRouter(); + const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false); + const [selected, setSelected] = useState( + null + ); + const [rows, setRows] = useState(keys); + const api = createApiClient(useEnvContext()); + const t = useTranslations(); + const [isRefreshing, setIsRefreshing] = useState(false); + const [createOpen, setCreateOpen] = useState(false); + + useEffect(() => { + setRows(keys); + }, [keys]); + + const refreshData = async () => { + setIsRefreshing(true); + try { + await new Promise((resolve) => setTimeout(resolve, 200)); + router.refresh(); + } catch (error) { + toast({ + title: t("error"), + description: t("refreshError"), + variant: "destructive" + }); + } finally { + setIsRefreshing(false); + } + }; + + const deleteKey = async (siteProvisioningKeyId: string) => { + try { + await api.delete( + `/org/${orgId}/site-provisioning-key/${siteProvisioningKeyId}` + ); + router.refresh(); + setIsDeleteModalOpen(false); + setSelected(null); + setRows((prev) => prev.filter((row) => row.id !== siteProvisioningKeyId)); + } catch (e) { + console.error(t("provisioningKeysErrorDelete"), e); + toast({ + variant: "destructive", + title: t("provisioningKeysErrorDelete"), + description: formatAxiosError( + e, + t("provisioningKeysErrorDeleteMessage") + ) + }); + throw e; + } + }; + + const columns: ExtendedColumnDef[] = [ + { + accessorKey: 
"name", + enableHiding: false, + friendlyName: t("name"), + header: ({ column }) => { + return ( + + ); + } + }, + { + accessorKey: "key", + friendlyName: t("key"), + header: () => {t("key")}, + cell: ({ row }) => { + const r = row.original; + return {r.key}; + } + }, + { + accessorKey: "createdAt", + friendlyName: t("createdAt"), + header: () => {t("createdAt")}, + cell: ({ row }) => { + const r = row.original; + return {moment(r.createdAt).format("lll")}; + } + }, + { + id: "actions", + enableHiding: false, + header: () => , + cell: ({ row }) => { + const r = row.original; + return ( +
+ + + + + + { + setSelected(r); + setIsDeleteModalOpen(true); + }} + > + + {t("delete")} + + + + +
+ ); + } + } + ]; + + return ( + <> + + + {selected && ( + { + setIsDeleteModalOpen(val); + if (!val) { + setSelected(null); + } + }} + dialog={ +
+                        <div>
+                            <p>{t("provisioningKeysQuestionRemove")}</p>
+                            <p>{t("provisioningKeysMessageRemove")}</p>
+                        </div>
+ } + buttonText={t("provisioningKeysDeleteConfirm")} + onConfirm={async () => deleteKey(selected.id)} + string={selected.name} + title={t("provisioningKeysDelete")} + /> + )} + + setCreateOpen(true)} + onRefresh={refreshData} + isRefreshing={isRefreshing} + addButtonText={t("provisioningKeysAdd")} + enableColumnVisibility={true} + stickyLeftColumn="name" + stickyRightColumn="actions" + /> + + ); +} From b2eab95a3b724ea947cecb9a6c57411f021b2271 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 18:17:33 -0700 Subject: [PATCH 10/24] Pass at first endpoints --- server/routers/external.ts | 23 ++ .../routers/newt/createNewtProvisioningKey.ts | 108 ++++++++ server/routers/newt/index.ts | 2 + server/routers/newt/registerNewt.ts | 240 ++++++++++++++++++ 4 files changed, 373 insertions(+) create mode 100644 server/routers/newt/createNewtProvisioningKey.ts create mode 100644 server/routers/newt/registerNewt.ts diff --git a/server/routers/external.ts b/server/routers/external.ts index 45ab58bba..c32a7a8e9 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -102,6 +102,13 @@ authenticated.put( logActionAudit(ActionsEnum.createSite), site.createSite ); + +authenticated.put( + "/org/:orgId/newt/provisioning-key", + verifyOrgAccess, + verifyUserHasAction(ActionsEnum.createSite), + newt.createNewtProvisioningKey +); authenticated.get( "/org/:orgId/sites", verifyOrgAccess, @@ -1202,6 +1209,22 @@ authRouter.post( }), newt.getNewtToken ); + +authRouter.post( + "/newt/register", + rateLimit({ + windowMs: 15 * 60 * 1000, + max: 30, + keyGenerator: (req) => + `newtRegister:${req.body.provisioningKey?.split(".")[0] || ipKeyGenerator(req.ip || "")}`, + handler: (req, res, next) => { + const message = `You can only register a newt ${30} times every ${15} minutes. 
Please try again later.`; + return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message)); + }, + store: createStore() + }), + newt.registerNewt +); authRouter.post( "/olm/get-token", rateLimit({ diff --git a/server/routers/newt/createNewtProvisioningKey.ts b/server/routers/newt/createNewtProvisioningKey.ts new file mode 100644 index 000000000..2af4d166d --- /dev/null +++ b/server/routers/newt/createNewtProvisioningKey.ts @@ -0,0 +1,108 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { db } from "@server/db"; +import { newtProvisioningKeys, orgs } from "@server/db"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { eq } from "drizzle-orm"; +import { fromError } from "zod-validation-error"; +import { + generateId, + generateIdFromEntropySize +} from "@server/auth/sessions/app"; +import { hashPassword } from "@server/auth/password"; +import moment from "moment"; + +const paramsSchema = z.object({ + orgId: z.string().nonempty() +}); + +const bodySchema = z.object({ + expiresAt: z.number().int().positive().optional() // optional Unix timestamp (ms) +}); + +export type CreateNewtProvisioningKeyBody = z.infer; + +export type CreateNewtProvisioningKeyResponse = { + provisioningKeyId: string; + provisioningKey: string; // returned only once: "id.secret" + lastChars: string; + createdAt: string; + expiresAt: number | null; +}; + +export async function createNewtProvisioningKey( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const parsedBody = bodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { orgId } = parsedParams.data; + const { expiresAt } = parsedBody.data; + + // Verify org exists + const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); + if (!org) { + return next( + createHttpError(HttpCode.NOT_FOUND, `Organization with ID ${orgId} not found`) + ); + } + + const provisioningKeyId = generateId(15); + const secret = generateIdFromEntropySize(25); + const keyHash = await hashPassword(secret); + const lastChars = secret.slice(-4); + const createdAt = moment().toISOString(); + const provisioningKey = `${provisioningKeyId}.${secret}`; + + await db.insert(newtProvisioningKeys).values({ + provisioningKeyId, + orgId, + keyHash, + lastChars, + createdAt, + expiresAt: expiresAt ?? null + }); + + return response(res, { + data: { + provisioningKeyId, + provisioningKey, + lastChars, + createdAt, + expiresAt: expiresAt ?? 
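// (Usage sketch, hypothetical values: only the hash of the secret is
// persisted, so the caller must store the combined "id.secret" string from
// this one response —
//     const { provisioningKey } = resp.data;
//     // e.g. "a1b2c3d4e5f6g7h.s3cr3t..." — shown once, never recoverable
// and later present it to POST /newt/register, which splits it on the first
// "." and verifies the secret portion against the stored keyHash.)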
null + }, + success: true, + error: false, + message: "Provisioning key created successfully", + status: HttpCode.CREATED + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/server/routers/newt/index.ts b/server/routers/newt/index.ts index 63d1e1068..76207e8d4 100644 --- a/server/routers/newt/index.ts +++ b/server/routers/newt/index.ts @@ -9,3 +9,5 @@ export * from "./handleApplyBlueprintMessage"; export * from "./handleNewtPingMessage"; export * from "./handleNewtDisconnectingMessage"; export * from "./handleConnectionLogMessage"; +export * from "./registerNewt"; +export * from "./createNewtProvisioningKey"; diff --git a/server/routers/newt/registerNewt.ts b/server/routers/newt/registerNewt.ts new file mode 100644 index 000000000..f301b452a --- /dev/null +++ b/server/routers/newt/registerNewt.ts @@ -0,0 +1,240 @@ +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { db } from "@server/db"; +import { + newtProvisioningKeys, + newts, + orgs, + roles, + roleSites, + sites +} from "@server/db"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { eq, and } from "drizzle-orm"; +import { fromError } from "zod-validation-error"; +import { verifyPassword, hashPassword } from "@server/auth/password"; +import { + generateId, + generateIdFromEntropySize +} from "@server/auth/sessions/app"; +import { getUniqueSiteName } from "@server/db/names"; +import moment from "moment"; +import { build } from "@server/build"; +import { usageService } from "@server/lib/billing/usageService"; +import { FeatureId } from "@server/lib/billing"; + +const bodySchema = z.object({ + provisioningKey: z.string().nonempty() +}); + +export type RegisterNewtBody = z.infer; + +export type RegisterNewtResponse = { + newtId: string; + secret: string; +}; + +export async function registerNewt( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedBody = bodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { provisioningKey } = parsedBody.data; + + // Keys are in the format "id.secret" + const dotIndex = provisioningKey.indexOf("."); + if (dotIndex === -1) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + "Invalid provisioning key format" + ) + ); + } + + const provisioningKeyId = provisioningKey.substring(0, dotIndex); + const provisioningKeySecret = provisioningKey.substring(dotIndex + 1); + + // Look up the provisioning key by ID + const [keyRecord] = await db + .select() + .from(newtProvisioningKeys) + .where( + eq(newtProvisioningKeys.provisioningKeyId, provisioningKeyId) + ) + .limit(1); + + if (!keyRecord) { + return next( + createHttpError(HttpCode.UNAUTHORIZED, "Invalid provisioning key") + ); + } + + // Verify the secret + const validSecret = await verifyPassword( + provisioningKeySecret, + keyRecord.keyHash + ); + if (!validSecret) { + return next( + createHttpError(HttpCode.UNAUTHORIZED, "Invalid provisioning key") + ); + } + + // Check if key has already been used + if (keyRecord.siteId !== null) { + return next( + createHttpError( + HttpCode.CONFLICT, + "Provisioning key has already been used" + ) + ); + } + + // Check expiry + if (keyRecord.expiresAt !== null && 
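// (Note: expiresAt is the optional Unix timestamp in milliseconds accepted by
// the create endpoint above, so comparing it against Date.now() — also
// milliseconds — is consistent; e.g. a key minted with
// expiresAt = Date.now() + 86_400_000 stays valid for 24 hours.)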
keyRecord.expiresAt < Date.now()) { + return next( + createHttpError(HttpCode.GONE, "Provisioning key has expired") + ); + } + + const { orgId } = keyRecord; + + // Verify the org exists + const [org] = await db + .select() + .from(orgs) + .where(eq(orgs.orgId, orgId)); + if (!org) { + return next( + createHttpError(HttpCode.NOT_FOUND, "Organization not found") + ); + } + + // SaaS billing check + if (build == "saas") { + const usage = await usageService.getUsage(orgId, FeatureId.SITES); + if (!usage) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + "No usage data found for this organization" + ) + ); + } + const rejectSites = await usageService.checkLimitSet( + orgId, + FeatureId.SITES, + { + ...usage, + instantaneousValue: (usage.instantaneousValue || 0) + 1 + } + ); + if (rejectSites) { + return next( + createHttpError( + HttpCode.FORBIDDEN, + "Site limit exceeded. Please upgrade your plan." + ) + ); + } + } + + const niceId = await getUniqueSiteName(orgId); + const newtId = generateId(15); + const newtSecret = generateIdFromEntropySize(25); + const secretHash = await hashPassword(newtSecret); + + let newSiteId: number | undefined; + + await db.transaction(async (trx) => { + // Create the site (type "newt", name = niceId) + const [newSite] = await trx + .insert(sites) + .values({ + orgId, + name: niceId, + niceId, + type: "newt", + dockerSocketEnabled: true + }) + .returning(); + + newSiteId = newSite.siteId; + + // Grant admin role access to the new site + const [adminRole] = await trx + .select() + .from(roles) + .where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId))) + .limit(1); + + if (!adminRole) { + throw new Error(`Admin role not found for org ${orgId}`); + } + + await trx.insert(roleSites).values({ + roleId: adminRole.roleId, + siteId: newSite.siteId + }); + + // Create the newt for this site + await trx.insert(newts).values({ + newtId, + secretHash, + siteId: newSite.siteId, + dateCreated: moment().toISOString() + }); + + // Mark the provisioning key as used + await trx + .update(newtProvisioningKeys) + .set({ siteId: newSite.siteId }) + .where( + eq( + newtProvisioningKeys.provisioningKeyId, + provisioningKeyId + ) + ); + + await usageService.add(orgId, FeatureId.SITES, 1, trx); + }); + + logger.info( + `Provisioned new site (ID: ${newSiteId}) and newt (ID: ${newtId}) for org ${orgId} via provisioning key ${provisioningKeyId}` + ); + + return response(res, { + data: { + newtId, + secret: newtSecret + }, + success: true, + error: false, + message: "Newt registered successfully", + status: HttpCode.CREATED + }); + } catch (error) { + logger.error(error); + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "An error occurred" + ) + ); + } +} \ No newline at end of file From 0b5b6ed5a37c49c311fc2dcb520aba3e21a65e47 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 18:26:10 -0700 Subject: [PATCH 11/24] Adjust register endpoint --- server/routers/external.ts | 7 +- .../routers/newt/createNewtProvisioningKey.ts | 108 ------------------ server/routers/newt/index.ts | 1 - server/routers/newt/registerNewt.ts | 71 ++++++------ .../createSiteProvisioningKey.ts | 3 + 5 files changed, 42 insertions(+), 148 deletions(-) delete mode 100644 server/routers/newt/createNewtProvisioningKey.ts diff --git a/server/routers/external.ts b/server/routers/external.ts index a658fa811..297cb2894 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -105,12 +105,7 @@ authenticated.put( site.createSite ); -authenticated.put( - 
"/org/:orgId/newt/provisioning-key", - verifyOrgAccess, - verifyUserHasAction(ActionsEnum.createSite), - newt.createNewtProvisioningKey -); + authenticated.get( "/org/:orgId/sites", verifyOrgAccess, diff --git a/server/routers/newt/createNewtProvisioningKey.ts b/server/routers/newt/createNewtProvisioningKey.ts deleted file mode 100644 index 2af4d166d..000000000 --- a/server/routers/newt/createNewtProvisioningKey.ts +++ /dev/null @@ -1,108 +0,0 @@ -import { Request, Response, NextFunction } from "express"; -import { z } from "zod"; -import { db } from "@server/db"; -import { newtProvisioningKeys, orgs } from "@server/db"; -import response from "@server/lib/response"; -import HttpCode from "@server/types/HttpCode"; -import createHttpError from "http-errors"; -import logger from "@server/logger"; -import { eq } from "drizzle-orm"; -import { fromError } from "zod-validation-error"; -import { - generateId, - generateIdFromEntropySize -} from "@server/auth/sessions/app"; -import { hashPassword } from "@server/auth/password"; -import moment from "moment"; - -const paramsSchema = z.object({ - orgId: z.string().nonempty() -}); - -const bodySchema = z.object({ - expiresAt: z.number().int().positive().optional() // optional Unix timestamp (ms) -}); - -export type CreateNewtProvisioningKeyBody = z.infer; - -export type CreateNewtProvisioningKeyResponse = { - provisioningKeyId: string; - provisioningKey: string; // returned only once: "id.secret" - lastChars: string; - createdAt: string; - expiresAt: number | null; -}; - -export async function createNewtProvisioningKey( - req: Request, - res: Response, - next: NextFunction -): Promise { - try { - const parsedParams = paramsSchema.safeParse(req.params); - if (!parsedParams.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - fromError(parsedParams.error).toString() - ) - ); - } - - const parsedBody = bodySchema.safeParse(req.body); - if (!parsedBody.success) { - return next( - createHttpError( - HttpCode.BAD_REQUEST, - fromError(parsedBody.error).toString() - ) - ); - } - - const { orgId } = parsedParams.data; - const { expiresAt } = parsedBody.data; - - // Verify org exists - const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId)); - if (!org) { - return next( - createHttpError(HttpCode.NOT_FOUND, `Organization with ID ${orgId} not found`) - ); - } - - const provisioningKeyId = generateId(15); - const secret = generateIdFromEntropySize(25); - const keyHash = await hashPassword(secret); - const lastChars = secret.slice(-4); - const createdAt = moment().toISOString(); - const provisioningKey = `${provisioningKeyId}.${secret}`; - - await db.insert(newtProvisioningKeys).values({ - provisioningKeyId, - orgId, - keyHash, - lastChars, - createdAt, - expiresAt: expiresAt ?? null - }); - - return response(res, { - data: { - provisioningKeyId, - provisioningKey, - lastChars, - createdAt, - expiresAt: expiresAt ?? 
null - }, - success: true, - error: false, - message: "Provisioning key created successfully", - status: HttpCode.CREATED - }); - } catch (error) { - logger.error(error); - return next( - createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") - ); - } -} diff --git a/server/routers/newt/index.ts b/server/routers/newt/index.ts index 76207e8d4..33b5caf7c 100644 --- a/server/routers/newt/index.ts +++ b/server/routers/newt/index.ts @@ -10,4 +10,3 @@ export * from "./handleNewtPingMessage"; export * from "./handleNewtDisconnectingMessage"; export * from "./handleConnectionLogMessage"; export * from "./registerNewt"; -export * from "./createNewtProvisioningKey"; diff --git a/server/routers/newt/registerNewt.ts b/server/routers/newt/registerNewt.ts index f301b452a..b999eb35c 100644 --- a/server/routers/newt/registerNewt.ts +++ b/server/routers/newt/registerNewt.ts @@ -2,7 +2,8 @@ import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { db } from "@server/db"; import { - newtProvisioningKeys, + siteProvisioningKeys, + siteProvisioningKeyOrg, newts, orgs, roles, @@ -55,7 +56,7 @@ export async function registerNewt( const { provisioningKey } = parsedBody.data; - // Keys are in the format "id.secret" + // Keys are in the format "siteProvisioningKeyId.secret" const dotIndex = provisioningKey.indexOf("."); if (dotIndex === -1) { return next( @@ -69,46 +70,51 @@ export async function registerNewt( const provisioningKeyId = provisioningKey.substring(0, dotIndex); const provisioningKeySecret = provisioningKey.substring(dotIndex + 1); - // Look up the provisioning key by ID + // Look up the provisioning key by ID, joining to get the orgId const [keyRecord] = await db - .select() - .from(newtProvisioningKeys) + .select({ + siteProvisioningKeyId: + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyHash: + siteProvisioningKeys.siteProvisioningKeyHash, + orgId: siteProvisioningKeyOrg.orgId + }) + .from(siteProvisioningKeys) + .innerJoin( + siteProvisioningKeyOrg, + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyOrg.siteProvisioningKeyId + ) + ) .where( - eq(newtProvisioningKeys.provisioningKeyId, provisioningKeyId) + eq( + siteProvisioningKeys.siteProvisioningKeyId, + provisioningKeyId + ) ) .limit(1); if (!keyRecord) { - return next( - createHttpError(HttpCode.UNAUTHORIZED, "Invalid provisioning key") - ); - } - - // Verify the secret - const validSecret = await verifyPassword( - provisioningKeySecret, - keyRecord.keyHash - ); - if (!validSecret) { - return next( - createHttpError(HttpCode.UNAUTHORIZED, "Invalid provisioning key") - ); - } - - // Check if key has already been used - if (keyRecord.siteId !== null) { return next( createHttpError( - HttpCode.CONFLICT, - "Provisioning key has already been used" + HttpCode.UNAUTHORIZED, + "Invalid provisioning key" ) ); } - // Check expiry - if (keyRecord.expiresAt !== null && keyRecord.expiresAt < Date.now()) { + // Verify the secret portion against the stored hash + const validSecret = await verifyPassword( + provisioningKeySecret, + keyRecord.siteProvisioningKeyHash + ); + if (!validSecret) { return next( - createHttpError(HttpCode.GONE, "Provisioning key has expired") + createHttpError( + HttpCode.UNAUTHORIZED, + "Invalid provisioning key" + ) ); } @@ -200,13 +206,12 @@ export async function registerNewt( dateCreated: moment().toISOString() }); - // Mark the provisioning key as used + // Consume the provisioning key — cascade removes siteProvisioningKeyOrg await trx - 
.update(newtProvisioningKeys) - .set({ siteId: newSite.siteId }) + .delete(siteProvisioningKeys) .where( eq( - newtProvisioningKeys.provisioningKeyId, + siteProvisioningKeys.siteProvisioningKeyId, provisioningKeyId ) ); diff --git a/server/routers/siteProvisioning/createSiteProvisioningKey.ts b/server/routers/siteProvisioning/createSiteProvisioningKey.ts index 9bb298966..a10df65a6 100644 --- a/server/routers/siteProvisioning/createSiteProvisioningKey.ts +++ b/server/routers/siteProvisioning/createSiteProvisioningKey.ts @@ -28,6 +28,7 @@ export type CreateSiteProvisioningKeyResponse = { orgId: string; name: string; siteProvisioningKey: string; + provisioningKey: string; // combined "siteProvisioningKeyId.siteProvisioningKey" — put this in your newt config lastChars: string; createdAt: string; }; @@ -65,6 +66,7 @@ export async function createSiteProvisioningKey( const siteProvisioningKeyHash = await hashPassword(siteProvisioningKey); const lastChars = siteProvisioningKey.slice(-4); const createdAt = moment().toISOString(); + const provisioningKey = `${siteProvisioningKeyId}.${siteProvisioningKey}`; await db.transaction(async (trx) => { await trx.insert(siteProvisioningKeys).values({ @@ -88,6 +90,7 @@ export async function createSiteProvisioningKey( orgId, name, siteProvisioningKey, + provisioningKey, lastChars, createdAt }, From 3525b367b332eedcf689b0700ed47c6108120d11 Mon Sep 17 00:00:00 2001 From: miloschwartz Date: Tue, 24 Mar 2026 20:27:15 -0700 Subject: [PATCH 12/24] move to private routes --- messages/en-US.json | 16 + server/auth/actions.ts | 1 + server/db/pg/schema/privateSchema.ts | 6 +- server/db/sqlite/schema/privateSchema.ts | 6 +- server/lib/billing/tierMatrix.ts | 6 +- .../routers/billing/featureLifecycle.ts | 59 +++- server/private/routers/external.ts | 46 ++- .../createSiteProvisioningKey.ts | 67 +++- .../deleteSiteProvisioningKey.ts | 13 + .../private/routers/siteProvisioning/index.ts | 17 + .../listSiteProvisioningKeys.ts | 27 +- .../updateSiteProvisioningKey.ts | 199 ++++++++++++ server/routers/external.ts | 29 +- server/routers/siteProvisioning/index.ts | 3 - server/routers/siteProvisioning/types.ts | 41 +++ .../settings/provisioning/create/page.tsx | 10 - .../[orgId]/settings/provisioning/page.tsx | 14 +- .../CreateSiteProvisioningKeyCredenza.tsx | 162 +++++++++- .../EditSiteProvisioningKeyCredenza.tsx | 297 ++++++++++++++++++ src/components/LayoutMobileMenu.tsx | 2 +- src/components/LayoutSidebar.tsx | 18 +- src/components/ProductUpdates.tsx | 8 +- src/components/SiteProvisioningKeysTable.tsx | 106 ++++++- src/components/ui/data-table.tsx | 4 +- 24 files changed, 1054 insertions(+), 103 deletions(-) rename server/{ => private}/routers/siteProvisioning/createSiteProvisioningKey.ts (62%) rename server/{ => private}/routers/siteProvisioning/deleteSiteProvisioningKey.ts (90%) create mode 100644 server/private/routers/siteProvisioning/index.ts rename server/{ => private}/routers/siteProvisioning/listSiteProvisioningKeys.ts (80%) create mode 100644 server/private/routers/siteProvisioning/updateSiteProvisioningKey.ts delete mode 100644 server/routers/siteProvisioning/index.ts create mode 100644 server/routers/siteProvisioning/types.ts delete mode 100644 src/app/[orgId]/settings/provisioning/create/page.tsx create mode 100644 src/components/EditSiteProvisioningKeyCredenza.tsx diff --git a/messages/en-US.json b/messages/en-US.json index 1785f0491..a7d16f30f 100644 --- a/messages/en-US.json +++ b/messages/en-US.json @@ -342,6 +342,22 @@ "provisioningKeysSaveDescription": "You 
will only be able to see this once. Copy it to a secure place.", "provisioningKeysErrorCreate": "Error creating provisioning key", "provisioningKeysList": "New provisioning key", + "provisioningKeysMaxBatchSize": "Max batch size", + "provisioningKeysUnlimitedBatchSize": "Unlimited batch size (no limit)", + "provisioningKeysMaxBatchUnlimited": "Unlimited", + "provisioningKeysMaxBatchSizeInvalid": "Enter a valid max batch size (1–1,000,000).", + "provisioningKeysValidUntil": "Valid until", + "provisioningKeysValidUntilHint": "Leave empty for no expiration.", + "provisioningKeysValidUntilInvalid": "Enter a valid date and time.", + "provisioningKeysNumUsed": "Times used", + "provisioningKeysLastUsed": "Last used", + "provisioningKeysNoExpiry": "No expiration", + "provisioningKeysNeverUsed": "Never", + "provisioningKeysEdit": "Edit Provisioning Key", + "provisioningKeysEditDescription": "Update the max batch size and expiration time for this key.", + "provisioningKeysUpdateError": "Error updating provisioning key", + "provisioningKeysUpdated": "Provisioning key updated", + "provisioningKeysUpdatedDescription": "Your changes have been saved.", "apiKeysSettings": "{apiKeyName} Settings", "userTitle": "Manage All Users", "userDescription": "View and manage all users in the system", diff --git a/server/auth/actions.ts b/server/auth/actions.ts index 6cdc4fa0a..20f1fe795 100644 --- a/server/auth/actions.ts +++ b/server/auth/actions.ts @@ -111,6 +111,7 @@ export enum ActionsEnum { getApiKey = "getApiKey", createSiteProvisioningKey = "createSiteProvisioningKey", listSiteProvisioningKeys = "listSiteProvisioningKeys", + updateSiteProvisioningKey = "updateSiteProvisioningKey", deleteSiteProvisioningKey = "deleteSiteProvisioningKey", getCertificate = "getCertificate", restartCertificate = "restartCertificate", diff --git a/server/db/pg/schema/privateSchema.ts b/server/db/pg/schema/privateSchema.ts index b1dc98253..bb1e866c4 100644 --- a/server/db/pg/schema/privateSchema.ts +++ b/server/db/pg/schema/privateSchema.ts @@ -387,7 +387,11 @@ export const siteProvisioningKeys = pgTable("siteProvisioningKeys", { name: varchar("name", { length: 255 }).notNull(), siteProvisioningKeyHash: text("siteProvisioningKeyHash").notNull(), lastChars: varchar("lastChars", { length: 4 }).notNull(), - createdAt: varchar("dateCreated", { length: 255 }).notNull() + createdAt: varchar("dateCreated", { length: 255 }).notNull(), + lastUsed: varchar("lastUsed", { length: 255 }), + maxBatchSize: integer("maxBatchSize"), // null = no limit + numUsed: integer("numUsed").notNull().default(0), + validUntil: varchar("validUntil", { length: 255 }) }); export const siteProvisioningKeyOrg = pgTable( diff --git a/server/db/sqlite/schema/privateSchema.ts b/server/db/sqlite/schema/privateSchema.ts index e78343d6d..5913497b3 100644 --- a/server/db/sqlite/schema/privateSchema.ts +++ b/server/db/sqlite/schema/privateSchema.ts @@ -371,7 +371,11 @@ export const siteProvisioningKeys = sqliteTable("siteProvisioningKeys", { name: text("name").notNull(), siteProvisioningKeyHash: text("siteProvisioningKeyHash").notNull(), lastChars: text("lastChars").notNull(), - createdAt: text("dateCreated").notNull() + createdAt: text("dateCreated").notNull(), + lastUsed: text("lastUsed"), + maxBatchSize: integer("maxBatchSize"), // null = no limit + numUsed: integer("numUsed").notNull().default(0), + validUntil: text("validUntil") }); export const siteProvisioningKeyOrg = sqliteTable( diff --git a/server/lib/billing/tierMatrix.ts b/server/lib/billing/tierMatrix.ts 
index f8a0cd2f5..8c41be5bf 100644 --- a/server/lib/billing/tierMatrix.ts +++ b/server/lib/billing/tierMatrix.ts @@ -16,7 +16,8 @@ export enum TierFeature { SessionDurationPolicies = "sessionDurationPolicies", // handle downgrade by setting to default duration PasswordExpirationPolicies = "passwordExpirationPolicies", // handle downgrade by setting to default duration AutoProvisioning = "autoProvisioning", // handle downgrade by disabling auto provisioning - SshPam = "sshPam" + SshPam = "sshPam", + SiteProvisioningKeys = "siteProvisioningKeys" // handle downgrade by revoking keys if needed } export const tierMatrix: Record = { @@ -50,5 +51,6 @@ export const tierMatrix: Record = { "enterprise" ], [TierFeature.AutoProvisioning]: ["tier1", "tier3", "enterprise"], - [TierFeature.SshPam]: ["tier1", "tier3", "enterprise"] + [TierFeature.SshPam]: ["tier1", "tier3", "enterprise"], + [TierFeature.SiteProvisioningKeys]: ["enterprise"] }; diff --git a/server/private/routers/billing/featureLifecycle.ts b/server/private/routers/billing/featureLifecycle.ts index 9536a87f0..32e81784b 100644 --- a/server/private/routers/billing/featureLifecycle.ts +++ b/server/private/routers/billing/featureLifecycle.ts @@ -26,9 +26,11 @@ import { orgs, resources, roles, + siteProvisioningKeyOrg, + siteProvisioningKeys, siteResources } from "@server/db"; -import { eq } from "drizzle-orm"; +import { and, eq } from "drizzle-orm"; /** * Get the maximum allowed retention days for a given tier @@ -291,6 +293,10 @@ async function disableFeature( await disableSshPam(orgId); break; + case TierFeature.SiteProvisioningKeys: + await disableSiteProvisioningKeys(orgId); + break; + default: logger.warn( `Unknown feature ${feature} for org ${orgId}, skipping` @@ -326,6 +332,57 @@ async function disableSshPam(orgId: string): Promise { ); } +async function disableSiteProvisioningKeys(orgId: string): Promise { + const rows = await db + .select({ + siteProvisioningKeyId: + siteProvisioningKeyOrg.siteProvisioningKeyId + }) + .from(siteProvisioningKeyOrg) + .where(eq(siteProvisioningKeyOrg.orgId, orgId)); + + for (const { siteProvisioningKeyId } of rows) { + await db.transaction(async (trx) => { + await trx + .delete(siteProvisioningKeyOrg) + .where( + and( + eq( + siteProvisioningKeyOrg.siteProvisioningKeyId, + siteProvisioningKeyId + ), + eq(siteProvisioningKeyOrg.orgId, orgId) + ) + ); + + const remaining = await trx + .select() + .from(siteProvisioningKeyOrg) + .where( + eq( + siteProvisioningKeyOrg.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ); + + if (remaining.length === 0) { + await trx + .delete(siteProvisioningKeys) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ); + } + }); + } + + logger.info( + `Removed site provisioning keys for org ${orgId} after tier downgrade` + ); +} + async function disableLoginPageBranding(orgId: string): Promise { const [existingBranding] = await db .select() diff --git a/server/private/routers/external.ts b/server/private/routers/external.ts index f06ad4517..6cac25e20 100644 --- a/server/private/routers/external.ts +++ b/server/private/routers/external.ts @@ -26,6 +26,7 @@ import * as misc from "#private/routers/misc"; import * as reKey from "#private/routers/re-key"; import * as approval from "#private/routers/approvals"; import * as ssh from "#private/routers/ssh"; +import * as siteProvisioning from "#private/routers/siteProvisioning"; import { verifyOrgAccess, @@ -33,7 +34,8 @@ import { verifyUserIsServerAdmin, verifySiteAccess, verifyClientAccess, 
- verifyLimits + verifyLimits, + verifySiteProvisioningKeyAccess } from "@server/middlewares"; import { ActionsEnum } from "@server/auth/actions"; import { @@ -537,3 +539,45 @@ authenticated.post( // logActionAudit(ActionsEnum.signSshKey), // it is handled inside of the function below so we can include more metadata ssh.signSshKey ); + +authenticated.put( + "/org/:orgId/site-provisioning-key", + verifyValidLicense, + verifyValidSubscription(tierMatrix.siteProvisioningKeys), + verifyOrgAccess, + verifyLimits, + verifyUserHasAction(ActionsEnum.createSiteProvisioningKey), + logActionAudit(ActionsEnum.createSiteProvisioningKey), + siteProvisioning.createSiteProvisioningKey +); + +authenticated.get( + "/org/:orgId/site-provisioning-keys", + verifyValidLicense, + verifyValidSubscription(tierMatrix.siteProvisioningKeys), + verifyOrgAccess, + verifyUserHasAction(ActionsEnum.listSiteProvisioningKeys), + siteProvisioning.listSiteProvisioningKeys +); + +authenticated.delete( + "/org/:orgId/site-provisioning-key/:siteProvisioningKeyId", + verifyValidLicense, + verifyValidSubscription(tierMatrix.siteProvisioningKeys), + verifyOrgAccess, + verifySiteProvisioningKeyAccess, + verifyUserHasAction(ActionsEnum.deleteSiteProvisioningKey), + logActionAudit(ActionsEnum.deleteSiteProvisioningKey), + siteProvisioning.deleteSiteProvisioningKey +); + +authenticated.patch( + "/org/:orgId/site-provisioning-key/:siteProvisioningKeyId", + verifyValidLicense, + verifyValidSubscription(tierMatrix.siteProvisioningKeys), + verifyOrgAccess, + verifySiteProvisioningKeyAccess, + verifyUserHasAction(ActionsEnum.updateSiteProvisioningKey), + logActionAudit(ActionsEnum.updateSiteProvisioningKey), + siteProvisioning.updateSiteProvisioningKey +); diff --git a/server/routers/siteProvisioning/createSiteProvisioningKey.ts b/server/private/routers/siteProvisioning/createSiteProvisioningKey.ts similarity index 62% rename from server/routers/siteProvisioning/createSiteProvisioningKey.ts rename to server/private/routers/siteProvisioning/createSiteProvisioningKey.ts index 9bb298966..45d980810 100644 --- a/server/routers/siteProvisioning/createSiteProvisioningKey.ts +++ b/server/private/routers/siteProvisioning/createSiteProvisioningKey.ts @@ -1,3 +1,16 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. 
+ */ + import { NextFunction, Request, Response } from "express"; import { db, siteProvisioningKeyOrg, siteProvisioningKeys } from "@server/db"; import HttpCode from "@server/types/HttpCode"; @@ -12,26 +25,37 @@ import { } from "@server/auth/sessions/app"; import logger from "@server/logger"; import { hashPassword } from "@server/auth/password"; +import type { CreateSiteProvisioningKeyResponse } from "@server/routers/siteProvisioning/types"; const paramsSchema = z.object({ orgId: z.string().nonempty() }); -const bodySchema = z.strictObject({ - name: z.string().min(1).max(255) -}); +const bodySchema = z + .strictObject({ + name: z.string().min(1).max(255), + maxBatchSize: z.union([ + z.null(), + z.coerce.number().int().positive().max(1_000_000) + ]), + validUntil: z.string().max(255).optional() + }) + .superRefine((data, ctx) => { + const v = data.validUntil; + if (v == null || v.trim() === "") { + return; + } + if (Number.isNaN(Date.parse(v))) { + ctx.addIssue({ + code: "custom", + message: "Invalid validUntil", + path: ["validUntil"] + }); + } + }); export type CreateSiteProvisioningKeyBody = z.infer; -export type CreateSiteProvisioningKeyResponse = { - siteProvisioningKeyId: string; - orgId: string; - name: string; - siteProvisioningKey: string; - lastChars: string; - createdAt: string; -}; - export async function createSiteProvisioningKey( req: Request, res: Response, @@ -58,7 +82,12 @@ export async function createSiteProvisioningKey( } const { orgId } = parsedParams.data; - const { name } = parsedBody.data; + const { name, maxBatchSize } = parsedBody.data; + const vuRaw = parsedBody.data.validUntil; + const validUntil = + vuRaw == null || vuRaw.trim() === "" + ? null + : new Date(Date.parse(vuRaw)).toISOString(); const siteProvisioningKeyId = `spk-${generateId(15)}`; const siteProvisioningKey = generateIdFromEntropySize(25); @@ -72,7 +101,11 @@ export async function createSiteProvisioningKey( name, siteProvisioningKeyHash, createdAt, - lastChars + lastChars, + lastUsed: null, + maxBatchSize, + numUsed: 0, + validUntil }); await trx.insert(siteProvisioningKeyOrg).values({ @@ -89,7 +122,11 @@ export async function createSiteProvisioningKey( name, siteProvisioningKey, lastChars, - createdAt + createdAt, + lastUsed: null, + maxBatchSize, + numUsed: 0, + validUntil }, success: true, error: false, diff --git a/server/routers/siteProvisioning/deleteSiteProvisioningKey.ts b/server/private/routers/siteProvisioning/deleteSiteProvisioningKey.ts similarity index 90% rename from server/routers/siteProvisioning/deleteSiteProvisioningKey.ts rename to server/private/routers/siteProvisioning/deleteSiteProvisioningKey.ts index d1da01d97..fc8b05e60 100644 --- a/server/routers/siteProvisioning/deleteSiteProvisioningKey.ts +++ b/server/private/routers/siteProvisioning/deleteSiteProvisioningKey.ts @@ -1,3 +1,16 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. 
+ */ + import { Request, Response, NextFunction } from "express"; import { z } from "zod"; import { diff --git a/server/private/routers/siteProvisioning/index.ts b/server/private/routers/siteProvisioning/index.ts new file mode 100644 index 000000000..d143274f6 --- /dev/null +++ b/server/private/routers/siteProvisioning/index.ts @@ -0,0 +1,17 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +export * from "./createSiteProvisioningKey"; +export * from "./listSiteProvisioningKeys"; +export * from "./deleteSiteProvisioningKey"; +export * from "./updateSiteProvisioningKey"; diff --git a/server/routers/siteProvisioning/listSiteProvisioningKeys.ts b/server/private/routers/siteProvisioning/listSiteProvisioningKeys.ts similarity index 80% rename from server/routers/siteProvisioning/listSiteProvisioningKeys.ts rename to server/private/routers/siteProvisioning/listSiteProvisioningKeys.ts index 65360625c..5f7531a2c 100644 --- a/server/routers/siteProvisioning/listSiteProvisioningKeys.ts +++ b/server/private/routers/siteProvisioning/listSiteProvisioningKeys.ts @@ -1,3 +1,16 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + import { db, siteProvisioningKeyOrg, @@ -11,6 +24,7 @@ import createHttpError from "http-errors"; import { z } from "zod"; import { fromError } from "zod-validation-error"; import { eq } from "drizzle-orm"; +import type { ListSiteProvisioningKeysResponse } from "@server/routers/siteProvisioning/types"; const paramsSchema = z.object({ orgId: z.string().nonempty() @@ -39,7 +53,11 @@ function querySiteProvisioningKeys(orgId: string) { orgId: siteProvisioningKeyOrg.orgId, lastChars: siteProvisioningKeys.lastChars, createdAt: siteProvisioningKeys.createdAt, - name: siteProvisioningKeys.name + name: siteProvisioningKeys.name, + lastUsed: siteProvisioningKeys.lastUsed, + maxBatchSize: siteProvisioningKeys.maxBatchSize, + numUsed: siteProvisioningKeys.numUsed, + validUntil: siteProvisioningKeys.validUntil }) .from(siteProvisioningKeyOrg) .innerJoin( @@ -52,13 +70,6 @@ function querySiteProvisioningKeys(orgId: string) { .where(eq(siteProvisioningKeyOrg.orgId, orgId)); } -export type ListSiteProvisioningKeysResponse = { - siteProvisioningKeys: Awaited< - ReturnType - >; - pagination: { total: number; limit: number; offset: number }; -}; - export async function listSiteProvisioningKeys( req: Request, res: Response, diff --git a/server/private/routers/siteProvisioning/updateSiteProvisioningKey.ts b/server/private/routers/siteProvisioning/updateSiteProvisioningKey.ts new file mode 100644 index 000000000..526d8bfb8 --- /dev/null +++ b/server/private/routers/siteProvisioning/updateSiteProvisioningKey.ts @@ -0,0 +1,199 @@ +/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. 
+ * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ + +import { Request, Response, NextFunction } from "express"; +import { z } from "zod"; +import { + db, + siteProvisioningKeyOrg, + siteProvisioningKeys +} from "@server/db"; +import { and, eq } from "drizzle-orm"; +import response from "@server/lib/response"; +import HttpCode from "@server/types/HttpCode"; +import createHttpError from "http-errors"; +import logger from "@server/logger"; +import { fromError } from "zod-validation-error"; +import type { UpdateSiteProvisioningKeyResponse } from "@server/routers/siteProvisioning/types"; + +const paramsSchema = z.object({ + siteProvisioningKeyId: z.string().nonempty(), + orgId: z.string().nonempty() +}); + +const bodySchema = z + .strictObject({ + maxBatchSize: z + .union([ + z.null(), + z.coerce.number().int().positive().max(1_000_000) + ]) + .optional(), + validUntil: z.string().max(255).optional() + }) + .superRefine((data, ctx) => { + if ( + data.maxBatchSize === undefined && + data.validUntil === undefined + ) { + ctx.addIssue({ + code: "custom", + message: "Provide maxBatchSize and/or validUntil", + path: ["maxBatchSize"] + }); + } + const v = data.validUntil; + if (v == null || v.trim() === "") { + return; + } + if (Number.isNaN(Date.parse(v))) { + ctx.addIssue({ + code: "custom", + message: "Invalid validUntil", + path: ["validUntil"] + }); + } + }); + +export type UpdateSiteProvisioningKeyBody = z.infer; + +export async function updateSiteProvisioningKey( + req: Request, + res: Response, + next: NextFunction +): Promise { + try { + const parsedParams = paramsSchema.safeParse(req.params); + if (!parsedParams.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedParams.error).toString() + ) + ); + } + + const parsedBody = bodySchema.safeParse(req.body); + if (!parsedBody.success) { + return next( + createHttpError( + HttpCode.BAD_REQUEST, + fromError(parsedBody.error).toString() + ) + ); + } + + const { siteProvisioningKeyId, orgId } = parsedParams.data; + const body = parsedBody.data; + + const [row] = await db + .select() + .from(siteProvisioningKeys) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ) + .innerJoin( + siteProvisioningKeyOrg, + and( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyOrg.siteProvisioningKeyId + ), + eq(siteProvisioningKeyOrg.orgId, orgId) + ) + ) + .limit(1); + + if (!row) { + return next( + createHttpError( + HttpCode.NOT_FOUND, + `Site provisioning key with ID ${siteProvisioningKeyId} not found` + ) + ); + } + + const setValues: { + maxBatchSize?: number | null; + validUntil?: string | null; + } = {}; + if (body.maxBatchSize !== undefined) { + setValues.maxBatchSize = body.maxBatchSize; + } + if (body.validUntil !== undefined) { + setValues.validUntil = + body.validUntil.trim() === "" + ? 
null + : new Date(Date.parse(body.validUntil)).toISOString(); + } + + await db + .update(siteProvisioningKeys) + .set(setValues) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ); + + const [updated] = await db + .select({ + siteProvisioningKeyId: + siteProvisioningKeys.siteProvisioningKeyId, + name: siteProvisioningKeys.name, + lastChars: siteProvisioningKeys.lastChars, + createdAt: siteProvisioningKeys.createdAt, + lastUsed: siteProvisioningKeys.lastUsed, + maxBatchSize: siteProvisioningKeys.maxBatchSize, + numUsed: siteProvisioningKeys.numUsed, + validUntil: siteProvisioningKeys.validUntil + }) + .from(siteProvisioningKeys) + .where( + eq( + siteProvisioningKeys.siteProvisioningKeyId, + siteProvisioningKeyId + ) + ) + .limit(1); + + if (!updated) { + return next( + createHttpError( + HttpCode.INTERNAL_SERVER_ERROR, + "Failed to load updated site provisioning key" + ) + ); + } + + return response(res, { + data: { + ...updated, + orgId + }, + success: true, + error: false, + message: "Site provisioning key updated successfully", + status: HttpCode.OK + }); + } catch (error) { + logger.error(error); + return next( + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") + ); + } +} diff --git a/server/routers/external.ts b/server/routers/external.ts index 90f208863..45ab58bba 100644 --- a/server/routers/external.ts +++ b/server/routers/external.ts @@ -15,7 +15,6 @@ import * as accessToken from "./accessToken"; import * as idp from "./idp"; import * as blueprints from "./blueprints"; import * as apiKeys from "./apiKeys"; -import * as siteProvisioning from "./siteProvisioning"; import * as logs from "./auditLogs"; import * as newt from "./newt"; import * as olm from "./olm"; @@ -43,8 +42,7 @@ import { verifyUserIsOrgOwner, verifySiteResourceAccess, verifyOlmAccess, - verifyLimits, - verifySiteProvisioningKeyAccess + verifyLimits } from "@server/middlewares"; import { ActionsEnum } from "@server/auth/actions"; import rateLimit, { ipKeyGenerator } from "express-rate-limit"; @@ -988,31 +986,6 @@ authenticated.get( apiKeys.listRootApiKeys ); -authenticated.put( - `/org/:orgId/site-provisioning-key`, - verifyOrgAccess, - verifyLimits, - verifyUserHasAction(ActionsEnum.createSiteProvisioningKey), - logActionAudit(ActionsEnum.createSiteProvisioningKey), - siteProvisioning.createSiteProvisioningKey -); - -authenticated.get( - `/org/:orgId/site-provisioning-keys`, - verifyOrgAccess, - verifyUserHasAction(ActionsEnum.listSiteProvisioningKeys), - siteProvisioning.listSiteProvisioningKeys -); - -authenticated.delete( - `/org/:orgId/site-provisioning-key/:siteProvisioningKeyId`, - verifyOrgAccess, - verifySiteProvisioningKeyAccess, - verifyUserHasAction(ActionsEnum.deleteSiteProvisioningKey), - logActionAudit(ActionsEnum.deleteSiteProvisioningKey), - siteProvisioning.deleteSiteProvisioningKey -); - authenticated.get( `/api-key/:apiKeyId/actions`, verifyUserIsServerAdmin, diff --git a/server/routers/siteProvisioning/index.ts b/server/routers/siteProvisioning/index.ts deleted file mode 100644 index b3f69f100..000000000 --- a/server/routers/siteProvisioning/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export * from "./createSiteProvisioningKey"; -export * from "./listSiteProvisioningKeys"; -export * from "./deleteSiteProvisioningKey"; diff --git a/server/routers/siteProvisioning/types.ts b/server/routers/siteProvisioning/types.ts new file mode 100644 index 000000000..d06c1fe26 --- /dev/null +++ b/server/routers/siteProvisioning/types.ts @@ -0,0 +1,41 @@ 
+export type SiteProvisioningKeyListItem = { + siteProvisioningKeyId: string; + orgId: string; + lastChars: string; + createdAt: string; + name: string; + lastUsed: string | null; + maxBatchSize: number | null; + numUsed: number; + validUntil: string | null; +}; + +export type ListSiteProvisioningKeysResponse = { + siteProvisioningKeys: SiteProvisioningKeyListItem[]; + pagination: { total: number; limit: number; offset: number }; +}; + +export type CreateSiteProvisioningKeyResponse = { + siteProvisioningKeyId: string; + orgId: string; + name: string; + siteProvisioningKey: string; + lastChars: string; + createdAt: string; + lastUsed: string | null; + maxBatchSize: number | null; + numUsed: number; + validUntil: string | null; +}; + +export type UpdateSiteProvisioningKeyResponse = { + siteProvisioningKeyId: string; + orgId: string; + name: string; + lastChars: string; + createdAt: string; + lastUsed: string | null; + maxBatchSize: number | null; + numUsed: number; + validUntil: string | null; +}; diff --git a/src/app/[orgId]/settings/provisioning/create/page.tsx b/src/app/[orgId]/settings/provisioning/create/page.tsx deleted file mode 100644 index 98573147a..000000000 --- a/src/app/[orgId]/settings/provisioning/create/page.tsx +++ /dev/null @@ -1,10 +0,0 @@ -import { redirect } from "next/navigation"; - -type PageProps = { - params: Promise<{ orgId: string }>; -}; - -export default async function ProvisioningCreateRedirect(props: PageProps) { - const params = await props.params; - redirect(`/${params.orgId}/settings/provisioning`); -} diff --git a/src/app/[orgId]/settings/provisioning/page.tsx b/src/app/[orgId]/settings/provisioning/page.tsx index f8a30b86f..e8b53104f 100644 --- a/src/app/[orgId]/settings/provisioning/page.tsx +++ b/src/app/[orgId]/settings/provisioning/page.tsx @@ -1,12 +1,14 @@ import { internal } from "@app/lib/api"; import { authCookieHeader } from "@app/lib/api/cookies"; import { AxiosResponse } from "axios"; +import { PaidFeaturesAlert } from "@app/components/PaidFeaturesAlert"; import SettingsSectionTitle from "@app/components/SettingsSectionTitle"; import SiteProvisioningKeysTable, { SiteProvisioningKeyRow } from "../../../../components/SiteProvisioningKeysTable"; -import { ListSiteProvisioningKeysResponse } from "@server/routers/siteProvisioning/listSiteProvisioningKeys"; +import { ListSiteProvisioningKeysResponse } from "@server/routers/siteProvisioning/types"; import { getTranslations } from "next-intl/server"; +import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; type ProvisioningPageProps = { params: Promise<{ orgId: string }>; @@ -34,7 +36,11 @@ export default async function ProvisioningPage(props: ProvisioningPageProps) { name: k.name, id: k.siteProvisioningKeyId, key: `${k.siteProvisioningKeyId}••••••••••••••••••••${k.lastChars}`, - createdAt: k.createdAt + createdAt: k.createdAt, + lastUsed: k.lastUsed, + maxBatchSize: k.maxBatchSize, + numUsed: k.numUsed, + validUntil: k.validUntil })); return ( @@ -44,6 +50,10 @@ export default async function ProvisioningPage(props: ProvisioningPageProps) { description={t("provisioningKeysDescription")} /> + + ); diff --git a/src/components/CreateSiteProvisioningKeyCredenza.tsx b/src/components/CreateSiteProvisioningKeyCredenza.tsx index 456731ed6..70c48ff08 100644 --- a/src/components/CreateSiteProvisioningKeyCredenza.tsx +++ b/src/components/CreateSiteProvisioningKeyCredenza.tsx @@ -13,18 +13,20 @@ import { import { Form, FormControl, + FormDescription, FormField, FormItem, FormLabel, FormMessage } 
from "@app/components/ui/form"; import { Button } from "@app/components/ui/button"; +import { Checkbox } from "@app/components/ui/checkbox"; import { Input } from "@app/components/ui/input"; import { Alert, AlertDescription, AlertTitle } from "@app/components/ui/alert"; import { useEnvContext } from "@app/hooks/useEnvContext"; import { toast } from "@app/hooks/useToast"; import { createApiClient, formatAxiosError } from "@app/lib/api"; -import { CreateSiteProvisioningKeyResponse } from "@server/routers/siteProvisioning/createSiteProvisioningKey"; +import { CreateSiteProvisioningKeyResponse } from "@server/routers/siteProvisioning/types"; import { AxiosResponse } from "axios"; import { InfoIcon } from "lucide-react"; import { useTranslations } from "next-intl"; @@ -55,30 +57,61 @@ export default function CreateSiteProvisioningKeyCredenza({ const [created, setCreated] = useState(null); - const createFormSchema = z.object({ - name: z - .string() - .min(1, { - message: t("nameMin", { len: 1 }) - }) - .max(255, { - message: t("nameMax", { len: 255 }) - }) - }); + const createFormSchema = z + .object({ + name: z + .string() + .min(1, { + message: t("nameMin", { len: 1 }) + }) + .max(255, { + message: t("nameMax", { len: 255 }) + }), + unlimitedBatchSize: z.boolean(), + maxBatchSize: z + .number() + .int() + .min(1, { message: t("provisioningKeysMaxBatchSizeInvalid") }) + .max(1_000_000, { + message: t("provisioningKeysMaxBatchSizeInvalid") + }), + validUntil: z.string().optional() + }) + .superRefine((data, ctx) => { + const v = data.validUntil; + if (v == null || v.trim() === "") { + return; + } + if (Number.isNaN(Date.parse(v))) { + ctx.addIssue({ + code: "custom", + message: t("provisioningKeysValidUntilInvalid"), + path: ["validUntil"] + }); + } + }); type CreateFormValues = z.infer; const form = useForm({ resolver: zodResolver(createFormSchema), defaultValues: { - name: "" + name: "", + unlimitedBatchSize: false, + maxBatchSize: 100, + validUntil: "" } }); useEffect(() => { if (!open) { setCreated(null); - form.reset({ name: "" }); + form.reset({ + name: "", + unlimitedBatchSize: false, + maxBatchSize: 100, + validUntil: "" + }); } }, [open, form]); @@ -88,7 +121,16 @@ export default function CreateSiteProvisioningKeyCredenza({ const res = await api .put< AxiosResponse - >(`/org/${orgId}/site-provisioning-key`, { name: data.name }) + >(`/org/${orgId}/site-provisioning-key`, { + name: data.name, + maxBatchSize: data.unlimitedBatchSize + ? null + : data.maxBatchSize, + validUntil: + data.validUntil == null || data.validUntil.trim() === "" + ? undefined + : data.validUntil + }) .catch((e) => { toast({ variant: "destructive", @@ -110,6 +152,8 @@ export default function CreateSiteProvisioningKeyCredenza({ created && `${created.siteProvisioningKeyId}.${created.siteProvisioningKey}`; + const unlimitedBatchSize = form.watch("unlimitedBatchSize"); + return ( @@ -149,6 +193,96 @@ export default function CreateSiteProvisioningKeyCredenza({ )} /> + ( + + + {t( + "provisioningKeysMaxBatchSize" + )} + + + { + const v = + e.target.value; + field.onChange( + v === "" + ? 
100 + : Number(v) + ); + }} + value={field.value} + /> + + + + )} + /> + ( + + + + field.onChange( + c === true + ) + } + /> + + + {t( + "provisioningKeysUnlimitedBatchSize" + )} + + + )} + /> + ( + + + {t( + "provisioningKeysValidUntil" + )} + + + + + + {t( + "provisioningKeysValidUntilHint" + )} + + + + )} + /> )} diff --git a/src/components/EditSiteProvisioningKeyCredenza.tsx b/src/components/EditSiteProvisioningKeyCredenza.tsx new file mode 100644 index 000000000..9603374d5 --- /dev/null +++ b/src/components/EditSiteProvisioningKeyCredenza.tsx @@ -0,0 +1,297 @@ +"use client"; + +import { + Credenza, + CredenzaBody, + CredenzaClose, + CredenzaContent, + CredenzaDescription, + CredenzaFooter, + CredenzaHeader, + CredenzaTitle +} from "@app/components/Credenza"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage +} from "@app/components/ui/form"; +import { Button } from "@app/components/ui/button"; +import { Checkbox } from "@app/components/ui/checkbox"; +import { Input } from "@app/components/ui/input"; +import { useEnvContext } from "@app/hooks/useEnvContext"; +import { toast } from "@app/hooks/useToast"; +import { createApiClient, formatAxiosError } from "@app/lib/api"; +import { UpdateSiteProvisioningKeyResponse } from "@server/routers/siteProvisioning/types"; +import { AxiosResponse } from "axios"; +import { useTranslations } from "next-intl"; +import { useRouter } from "next/navigation"; +import { useEffect, useState } from "react"; +import { useForm } from "react-hook-form"; +import { z } from "zod"; +import { zodResolver } from "@hookform/resolvers/zod"; +import moment from "moment"; + +const FORM_ID = "edit-site-provisioning-key-form"; + +export type EditableSiteProvisioningKey = { + id: string; + name: string; + maxBatchSize: number | null; + validUntil: string | null; +}; + +type EditSiteProvisioningKeyCredenzaProps = { + open: boolean; + setOpen: (open: boolean) => void; + orgId: string; + provisioningKey: EditableSiteProvisioningKey | null; +}; + +export default function EditSiteProvisioningKeyCredenza({ + open, + setOpen, + orgId, + provisioningKey +}: EditSiteProvisioningKeyCredenzaProps) { + const t = useTranslations(); + const router = useRouter(); + const api = createApiClient(useEnvContext()); + const [loading, setLoading] = useState(false); + + const editFormSchema = z + .object({ + name: z.string(), + unlimitedBatchSize: z.boolean(), + maxBatchSize: z + .number() + .int() + .min(1, { message: t("provisioningKeysMaxBatchSizeInvalid") }) + .max(1_000_000, { + message: t("provisioningKeysMaxBatchSizeInvalid") + }), + validUntil: z.string().optional() + }) + .superRefine((data, ctx) => { + const v = data.validUntil; + if (v == null || v.trim() === "") { + return; + } + if (Number.isNaN(Date.parse(v))) { + ctx.addIssue({ + code: "custom", + message: t("provisioningKeysValidUntilInvalid"), + path: ["validUntil"] + }); + } + }); + + type EditFormValues = z.infer; + + const form = useForm({ + resolver: zodResolver(editFormSchema), + defaultValues: { + name: "", + unlimitedBatchSize: false, + maxBatchSize: 100, + validUntil: "" + } + }); + + useEffect(() => { + if (!open || !provisioningKey) { + return; + } + form.reset({ + name: provisioningKey.name, + unlimitedBatchSize: provisioningKey.maxBatchSize == null, + maxBatchSize: provisioningKey.maxBatchSize ?? 100, + validUntil: provisioningKey.validUntil + ? 
moment(provisioningKey.validUntil).format("YYYY-MM-DDTHH:mm") + : "" + }); + }, [open, provisioningKey, form]); + + async function onSubmit(data: EditFormValues) { + if (!provisioningKey) { + return; + } + setLoading(true); + try { + const res = await api + .patch< + AxiosResponse + >( + `/org/${orgId}/site-provisioning-key/${provisioningKey.id}`, + { + maxBatchSize: data.unlimitedBatchSize + ? null + : data.maxBatchSize, + validUntil: + data.validUntil == null || + data.validUntil.trim() === "" + ? "" + : data.validUntil + } + ) + .catch((e) => { + toast({ + variant: "destructive", + title: t("provisioningKeysUpdateError"), + description: formatAxiosError(e) + }); + }); + + if (res && res.status === 200) { + toast({ + title: t("provisioningKeysUpdated"), + description: t("provisioningKeysUpdatedDescription") + }); + setOpen(false); + router.refresh(); + } + } finally { + setLoading(false); + } + } + + const unlimitedBatchSize = form.watch("unlimitedBatchSize"); + + if (!provisioningKey) { + return null; + } + + return ( + + + + {t("provisioningKeysEdit")} + + {t("provisioningKeysEditDescription")} + + + +
+ + ( + + {t("name")} + + + + + )} + /> + ( + + + {t("provisioningKeysMaxBatchSize")} + + + { + const v = e.target.value; + field.onChange( + v === "" + ? 100 + : Number(v) + ); + }} + value={field.value} + /> + + + + )} + /> + ( + + + + field.onChange(c === true) + } + /> + + + {t( + "provisioningKeysUnlimitedBatchSize" + )} + + + )} + /> + ( + + + {t("provisioningKeysValidUntil")} + + + + + + {t("provisioningKeysValidUntilHint")} + + + + )} + /> + + +
+ + + + + + +
+
+ ); +} diff --git a/src/components/LayoutMobileMenu.tsx b/src/components/LayoutMobileMenu.tsx index e1c883a2b..854cad6db 100644 --- a/src/components/LayoutMobileMenu.tsx +++ b/src/components/LayoutMobileMenu.tsx @@ -93,7 +93,7 @@ export function LayoutMobileMenu({ ) } > - + diff --git a/src/components/LayoutSidebar.tsx b/src/components/LayoutSidebar.tsx index e9e2d61eb..1cd2131f7 100644 --- a/src/components/LayoutSidebar.tsx +++ b/src/components/LayoutSidebar.tsx @@ -169,8 +169,8 @@ export function LayoutSidebar({ > @@ -222,36 +222,34 @@ export function LayoutSidebar({ )} -
- -
+
{canShowProductUpdates && ( -
+
)} {build === "enterprise" && ( -
+
)} {build === "oss" && ( -
+
)} {build === "saas" && ( -
+
)} {!isSidebarCollapsed && ( -
+
{loadFooterLinks() ? ( <> {loadFooterLinks()!.map((link, index) => ( diff --git a/src/components/ProductUpdates.tsx b/src/components/ProductUpdates.tsx index 01689d9d7..76ab0252d 100644 --- a/src/components/ProductUpdates.tsx +++ b/src/components/ProductUpdates.tsx @@ -192,13 +192,13 @@ function ProductUpdatesListPopup({
- +

{t("productUpdateWhatsNew")} @@ -346,13 +346,13 @@ function NewVersionAvailable({ rel="noopener noreferrer" className={cn( "relative z-2 group cursor-pointer block", - "rounded-md border border-primary/30 bg-linear-to-br dark:from-primary/20 from-primary/20 via-background to-background p-2 py-3 w-full flex flex-col gap-2 text-sm", + "rounded-md border bg-secondary p-2 py-3 w-full flex flex-col gap-2 text-sm", "transition duration-300 ease-in-out", "data-closed:opacity-0 data-closed:translate-y-full" )} >

- +

{t("pangolinUpdateAvailable")}

diff --git a/src/components/SiteProvisioningKeysTable.tsx b/src/components/SiteProvisioningKeysTable.tsx index 3fb3eb872..df7fd241c 100644 --- a/src/components/SiteProvisioningKeysTable.tsx +++ b/src/components/SiteProvisioningKeysTable.tsx @@ -15,19 +15,27 @@ import { ArrowUpDown, MoreHorizontal } from "lucide-react"; import { useRouter } from "next/navigation"; import { useEffect, useState } from "react"; import CreateSiteProvisioningKeyCredenza from "@app/components/CreateSiteProvisioningKeyCredenza"; +import EditSiteProvisioningKeyCredenza from "@app/components/EditSiteProvisioningKeyCredenza"; import ConfirmDeleteDialog from "@app/components/ConfirmDeleteDialog"; import { toast } from "@app/hooks/useToast"; import { formatAxiosError } from "@app/lib/api"; import { createApiClient } from "@app/lib/api"; import { useEnvContext } from "@app/hooks/useEnvContext"; +import { usePaidStatus } from "@app/hooks/usePaidStatus"; import moment from "moment"; import { useTranslations } from "next-intl"; +import { build } from "@server/build"; +import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix"; export type SiteProvisioningKeyRow = { id: string; key: string; name: string; createdAt: string; + lastUsed: string | null; + maxBatchSize: number | null; + numUsed: number; + validUntil: string | null; }; type SiteProvisioningKeysTableProps = { @@ -47,8 +55,15 @@ export default function SiteProvisioningKeysTable({ const [rows, setRows] = useState(keys); const api = createApiClient(useEnvContext()); const t = useTranslations(); + const { isPaidUser } = usePaidStatus(); + const canUseSiteProvisioning = + isPaidUser(tierMatrix[TierFeature.SiteProvisioningKeys]) && + build !== "oss"; const [isRefreshing, setIsRefreshing] = useState(false); const [createOpen, setCreateOpen] = useState(false); + const [editOpen, setEditOpen] = useState(false); + const [editingKey, setEditingKey] = + useState(null); useEffect(() => { setRows(keys); @@ -121,6 +136,68 @@ export default function SiteProvisioningKeysTable({ return {r.key}; } }, + { + accessorKey: "maxBatchSize", + friendlyName: t("provisioningKeysMaxBatchSize"), + header: () => ( + {t("provisioningKeysMaxBatchSize")} + ), + cell: ({ row }) => { + const r = row.original; + return ( + + {r.maxBatchSize == null + ? t("provisioningKeysMaxBatchUnlimited") + : r.maxBatchSize} + + ); + } + }, + { + accessorKey: "numUsed", + friendlyName: t("provisioningKeysNumUsed"), + header: () => ( + {t("provisioningKeysNumUsed")} + ), + cell: ({ row }) => { + const r = row.original; + return {r.numUsed}; + } + }, + { + accessorKey: "validUntil", + friendlyName: t("provisioningKeysValidUntil"), + header: () => ( + {t("provisioningKeysValidUntil")} + ), + cell: ({ row }) => { + const r = row.original; + return ( + + {r.validUntil + ? moment(r.validUntil).format("lll") + : t("provisioningKeysNoExpiry")} + + ); + } + }, + { + accessorKey: "lastUsed", + friendlyName: t("provisioningKeysLastUsed"), + header: () => ( + {t("provisioningKeysLastUsed")} + ), + cell: ({ row }) => { + const r = row.original; + return ( + + {r.lastUsed + ? 
moment(r.lastUsed).format("lll") + : t("provisioningKeysNeverUsed")} + + ); + } + }, { accessorKey: "createdAt", friendlyName: t("createdAt"), @@ -149,6 +226,16 @@ export default function SiteProvisioningKeysTable({ { + setEditingKey(r); + setEditOpen(true); + }} + > + {t("edit")} + + { setSelected(r); setIsDeleteModalOpen(true); @@ -174,6 +261,18 @@ export default function SiteProvisioningKeysTable({ orgId={orgId} /> + { + setEditOpen(v); + if (!v) { + setEditingKey(null); + } + }} + orgId={orgId} + provisioningKey={editingKey} + /> + {selected && ( setCreateOpen(true)} + onAdd={() => { + if (canUseSiteProvisioning) { + setCreateOpen(true); + } + }} + addButtonDisabled={!canUseSiteProvisioning} onRefresh={refreshData} isRefreshing={isRefreshing} addButtonText={t("provisioningKeysAdd")} diff --git a/src/components/ui/data-table.tsx b/src/components/ui/data-table.tsx index 834c56e88..a0c11ffdf 100644 --- a/src/components/ui/data-table.tsx +++ b/src/components/ui/data-table.tsx @@ -171,6 +171,7 @@ type DataTableProps = { title?: string; addButtonText?: string; onAdd?: () => void; + addButtonDisabled?: boolean; onRefresh?: () => void; isRefreshing?: boolean; searchPlaceholder?: string; @@ -203,6 +204,7 @@ export function DataTable({ title, addButtonText, onAdd, + addButtonDisabled = false, onRefresh, isRefreshing, searchPlaceholder = "Search...", @@ -635,7 +637,7 @@ export function DataTable({ )} {onAdd && addButtonText && (
- From 660420ddef68f2c5abee349bcd98709c91aac869 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 16:01:54 -0700 Subject: [PATCH 13/24] Disable everything if not paid --- src/app/[orgId]/settings/(private)/idp/create/page.tsx | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/app/[orgId]/settings/(private)/idp/create/page.tsx b/src/app/[orgId]/settings/(private)/idp/create/page.tsx index 4c783e9b2..fc2c6c382 100644 --- a/src/app/[orgId]/settings/(private)/idp/create/page.tsx +++ b/src/app/[orgId]/settings/(private)/idp/create/page.tsx @@ -275,6 +275,8 @@ export default function Page() { } } + const disabled = !isPaidUser(tierMatrix.orgOidc); + return ( <>
@@ -292,6 +294,9 @@ export default function Page() {
+ + +
@@ -812,9 +817,10 @@ export default function Page() {
+ ); } From 17eb93d045b94e7932421e5caca952224dc9ce0a Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 16:12:13 -0700 Subject: [PATCH 14/24] Add better pooling controls --- server/db/pg/driver.ts | 35 ++++++++++++--------- server/db/pg/logsDriver.ts | 35 ++++++++++++--------- server/db/pg/poolConfig.ts | 63 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 28 deletions(-) create mode 100644 server/db/pg/poolConfig.ts diff --git a/server/db/pg/driver.ts b/server/db/pg/driver.ts index 5b357d060..9366e32e1 100644 --- a/server/db/pg/driver.ts +++ b/server/db/pg/driver.ts @@ -1,7 +1,7 @@ import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres"; -import { Pool } from "pg"; import { readConfigFile } from "@server/lib/readConfigFile"; import { withReplicas } from "drizzle-orm/pg-core"; +import { createPool } from "./poolConfig"; function createDb() { const config = readConfigFile(); @@ -39,12 +39,17 @@ function createDb() { // Create connection pools instead of individual connections const poolConfig = config.postgres.pool; - const primaryPool = new Pool({ + const maxConnections = poolConfig?.max_connections || 20; + const idleTimeoutMs = poolConfig?.idle_timeout_ms || 30000; + const connectionTimeoutMs = poolConfig?.connection_timeout_ms || 5000; + + const primaryPool = createPool( connectionString, - max: poolConfig?.max_connections || 20, - idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000, - connectionTimeoutMillis: poolConfig?.connection_timeout_ms || 5000 - }); + maxConnections, + idleTimeoutMs, + connectionTimeoutMs, + "primary" + ); const replicas = []; @@ -55,14 +60,16 @@ function createDb() { }) ); } else { + const maxReplicaConnections = + poolConfig?.max_replica_connections || 20; for (const conn of replicaConnections) { - const replicaPool = new Pool({ - connectionString: conn.connection_string, - max: poolConfig?.max_replica_connections || 20, - idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000, - connectionTimeoutMillis: - poolConfig?.connection_timeout_ms || 5000 - }); + const replicaPool = createPool( + conn.connection_string, + maxReplicaConnections, + idleTimeoutMs, + connectionTimeoutMs, + "replica" + ); replicas.push( DrizzlePostgres(replicaPool, { logger: process.env.QUERY_LOGGING == "true" @@ -84,4 +91,4 @@ export default db; export const primaryDb = db.$primary; export type Transaction = Parameters< Parameters<(typeof db)["transaction"]>[0] ->[0]; +>[0]; \ No newline at end of file diff --git a/server/db/pg/logsDriver.ts b/server/db/pg/logsDriver.ts index 49e26f89f..146b8fb2f 100644 --- a/server/db/pg/logsDriver.ts +++ b/server/db/pg/logsDriver.ts @@ -1,9 +1,9 @@ import { drizzle as DrizzlePostgres } from "drizzle-orm/node-postgres"; -import { Pool } from "pg"; import { readConfigFile } from "@server/lib/readConfigFile"; import { withReplicas } from "drizzle-orm/pg-core"; import { build } from "@server/build"; import { db as mainDb, primaryDb as mainPrimaryDb } from "./driver"; +import { createPool } from "./poolConfig"; function createLogsDb() { // Only use separate logs database in SaaS builds @@ -42,12 +42,17 @@ function createLogsDb() { // Create separate connection pool for logs database const poolConfig = logsConfig?.pool || config.postgres?.pool; - const primaryPool = new Pool({ + const maxConnections = poolConfig?.max_connections || 20; + const idleTimeoutMs = poolConfig?.idle_timeout_ms || 30000; + const connectionTimeoutMs = poolConfig?.connection_timeout_ms || 5000; + + const primaryPool = createPool( 
connectionString, - max: poolConfig?.max_connections || 20, - idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000, - connectionTimeoutMillis: poolConfig?.connection_timeout_ms || 5000 - }); + maxConnections, + idleTimeoutMs, + connectionTimeoutMs, + "logs-primary" + ); const replicas = []; @@ -58,14 +63,16 @@ function createLogsDb() { }) ); } else { + const maxReplicaConnections = + poolConfig?.max_replica_connections || 20; for (const conn of replicaConnections) { - const replicaPool = new Pool({ - connectionString: conn.connection_string, - max: poolConfig?.max_replica_connections || 20, - idleTimeoutMillis: poolConfig?.idle_timeout_ms || 30000, - connectionTimeoutMillis: - poolConfig?.connection_timeout_ms || 5000 - }); + const replicaPool = createPool( + conn.connection_string, + maxReplicaConnections, + idleTimeoutMs, + connectionTimeoutMs, + "logs-replica" + ); replicas.push( DrizzlePostgres(replicaPool, { logger: process.env.QUERY_LOGGING == "true" @@ -84,4 +91,4 @@ function createLogsDb() { export const logsDb = createLogsDb(); export default logsDb; -export const primaryLogsDb = logsDb.$primary; +export const primaryLogsDb = logsDb.$primary; \ No newline at end of file diff --git a/server/db/pg/poolConfig.ts b/server/db/pg/poolConfig.ts new file mode 100644 index 000000000..f753121c1 --- /dev/null +++ b/server/db/pg/poolConfig.ts @@ -0,0 +1,63 @@ +import { Pool, PoolConfig } from "pg"; +import logger from "@server/logger"; + +export function createPoolConfig( + connectionString: string, + maxConnections: number, + idleTimeoutMs: number, + connectionTimeoutMs: number +): PoolConfig { + return { + connectionString, + max: maxConnections, + idleTimeoutMillis: idleTimeoutMs, + connectionTimeoutMillis: connectionTimeoutMs, + // TCP keepalive to prevent silent connection drops by NAT gateways, + // load balancers, and other intermediate network devices (e.g. AWS + // NAT Gateway drops idle TCP connections after ~350s) + keepAlive: true, + keepAliveInitialDelayMillis: 10000, // send first keepalive after 10s of idle + // Allow connections to be released and recreated more aggressively + // to avoid stale connections building up + allowExitOnIdle: false + }; +} + +export function attachPoolErrorHandlers(pool: Pool, label: string): void { + pool.on("error", (err) => { + // This catches errors on idle clients in the pool. Without this + // handler an unexpected disconnect would crash the process. 
+ logger.error( + `Unexpected error on idle ${label} database client: ${err.message}` + ); + }); + + pool.on("connect", (client) => { + // Set a statement timeout on every new connection so a single slow + // query can't block the pool forever + client.query("SET statement_timeout = '30s'").catch((err: Error) => { + logger.warn( + `Failed to set statement_timeout on ${label} client: ${err.message}` + ); + }); + }); +} + +export function createPool( + connectionString: string, + maxConnections: number, + idleTimeoutMs: number, + connectionTimeoutMs: number, + label: string +): Pool { + const pool = new Pool( + createPoolConfig( + connectionString, + maxConnections, + idleTimeoutMs, + connectionTimeoutMs + ) + ); + attachPoolErrorHandlers(pool, label); + return pool; +} \ No newline at end of file From 3e3b02021cdcc60fadb2d246efbe1f466b8054a2 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 16:26:56 -0700 Subject: [PATCH 15/24] Add ssh access log --- server/private/routers/ssh/signSshKey.ts | 19 +++++++++++++++++++ src/app/[orgId]/settings/logs/access/page.tsx | 16 ++++++++-------- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/server/private/routers/ssh/signSshKey.ts b/server/private/routers/ssh/signSshKey.ts index 5cffb4a34..c39d2e0ae 100644 --- a/server/private/routers/ssh/signSshKey.ts +++ b/server/private/routers/ssh/signSshKey.ts @@ -24,6 +24,7 @@ import { sites, userOrgs } from "@server/db"; +import { logAccessAudit } from "#private/lib/logAccessAudit"; import { isLicensedOrSubscribed } from "#private/lib/isLicencedOrSubscribed"; import { tierMatrix } from "@server/lib/billing/tierMatrix"; import response from "@server/lib/response"; @@ -463,6 +464,24 @@ export async function signSshKey( }) }); + await logAccessAudit({ + action: true, + type: "ssh", + orgId: orgId, + resourceId: resource.siteResourceId, + user: req.user + ? { username: req.user.username ?? "", userId: req.user.userId } + : undefined, + metadata: { + resourceName: resource.name, + siteId: resource.siteId, + sshUsername: usernameToUse, + sshHost: sshHost + }, + userAgent: req.headers["user-agent"], + requestIp: req.ip + }); + return response(res, { data: { certificate: cert.certificate, diff --git a/src/app/[orgId]/settings/logs/access/page.tsx b/src/app/[orgId]/settings/logs/access/page.tsx index 810022b98..dbb7b6708 100644 --- a/src/app/[orgId]/settings/logs/access/page.tsx +++ b/src/app/[orgId]/settings/logs/access/page.tsx @@ -493,7 +493,8 @@ export default function GeneralPage() { { value: "whitelistedEmail", label: "Whitelisted Email" - } + }, + { value: "ssh", label: "SSH" } ]} selectedValue={filters.type} onValueChange={(value) => @@ -507,13 +508,12 @@ export default function GeneralPage() { ); }, cell: ({ row }) => { - // should be capitalized first letter - return ( - - {row.original.type.charAt(0).toUpperCase() + - row.original.type.slice(1) || "-"} - - ); + const typeLabel = + row.original.type === "ssh" + ? 
"SSH" + : row.original.type.charAt(0).toUpperCase() + + row.original.type.slice(1); + return {typeLabel || "-"}; } }, { From 1f4cde5f7fae62fe6d6d75f3038ea68e68ad3c92 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 24 Mar 2026 18:13:57 -0700 Subject: [PATCH 16/24] Add license script --- license.py | 115 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 license.py diff --git a/license.py b/license.py new file mode 100644 index 000000000..865dfad7a --- /dev/null +++ b/license.py @@ -0,0 +1,115 @@ +import os +import sys + +# --- Configuration --- +# The header text to be added to the files. +HEADER_TEXT = """/* + * This file is part of a proprietary work. + * + * Copyright (c) 2025 Fossorial, Inc. + * All rights reserved. + * + * This file is licensed under the Fossorial Commercial License. + * You may not use this file except in compliance with the License. + * Unauthorized use, copying, modification, or distribution is strictly prohibited. + * + * This file is not licensed under the AGPLv3. + */ +""" + +def should_add_header(file_path): + """ + Checks if a file should receive the commercial license header. + Returns True if 'private' is in the path or file content. + """ + # Check if 'private' is in the file path (case-insensitive) + if 'server/private' in file_path.lower(): + return True + + # Check if 'private' is in the file content (case-insensitive) + # try: + # with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + # content = f.read() + # if 'private' in content.lower(): + # return True + # except Exception as e: + # print(f"Could not read file {file_path}: {e}") + + return False + +def process_directory(root_dir): + """ + Recursively scans a directory and adds headers to qualifying .ts or .tsx files, + skipping any 'node_modules' directories. + """ + print(f"Scanning directory: {root_dir}") + files_processed = 0 + headers_added = 0 + + for root, dirs, files in os.walk(root_dir): + # --- MODIFICATION --- + # Exclude 'node_modules' directories from the scan to improve performance. 
+ if 'node_modules' in dirs: + dirs.remove('node_modules') + + for file in files: + if file.endswith('.ts') or file.endswith('.tsx'): + file_path = os.path.join(root, file) + files_processed += 1 + + try: + with open(file_path, 'r+', encoding='utf-8') as f: + original_content = f.read() + has_header = original_content.startswith(HEADER_TEXT.strip()) + + if should_add_header(file_path): + # Add header only if it's not already there + if not has_header: + f.seek(0, 0) # Go to the beginning of the file + f.write(HEADER_TEXT.strip() + '\n\n' + original_content) + print(f"Added header to: {file_path}") + headers_added += 1 + else: + print(f"Header already exists in: {file_path}") + else: + # Remove header if it exists but shouldn't be there + if has_header: + # Find the end of the header and remove it (including following newlines) + header_with_newlines = HEADER_TEXT.strip() + '\n\n' + if original_content.startswith(header_with_newlines): + content_without_header = original_content[len(header_with_newlines):] + else: + # Handle case where there might be different newline patterns + header_end = len(HEADER_TEXT.strip()) + # Skip any newlines after the header + while header_end < len(original_content) and original_content[header_end] in '\n\r': + header_end += 1 + content_without_header = original_content[header_end:] + + f.seek(0) + f.write(content_without_header) + f.truncate() + print(f"Removed header from: {file_path}") + headers_added += 1 # Reusing counter for modifications + + except Exception as e: + print(f"Error processing file {file_path}: {e}") + + print("\n--- Scan Complete ---") + print(f"Total .ts or .tsx files found: {files_processed}") + print(f"Files modified (headers added/removed): {headers_added}") + + +if __name__ == "__main__": + # Get the target directory from the command line arguments. + # If no directory is provided, it uses the current directory ('.'). + if len(sys.argv) > 1: + target_directory = sys.argv[1] + else: + target_directory = '.' 
# Default to current directory
+
+    if not os.path.isdir(target_directory):
+        print(f"Error: Directory '{target_directory}' not found.")
+        sys.exit(1)
+
+    process_directory(os.path.abspath(target_directory))

From 348fcbcabf7016751c58314d6a3ceed84aa01b40 Mon Sep 17 00:00:00 2001
From: Owen
Date: Tue, 24 Mar 2026 17:39:43 -0700
Subject: [PATCH 17/24] Try to solve the problem

---
 server/cleanup.ts                            |   2 +
 server/private/cleanup.ts                    |   2 +
 server/private/routers/ws/ws.ts              |  37 +-
 server/routers/newt/getNewtToken.ts          |  22 +-
 server/routers/newt/handleNewtPingMessage.ts |  19 +-
 server/routers/newt/pingAccumulator.ts       | 395 +++++++++++++++++++
 server/routers/olm/getOlmToken.ts            |   3 +-
 server/routers/olm/handleOlmPingMessage.ts   |  23 +-
 server/routers/ws/messageHandlers.ts         |   7 +
 server/routers/ws/ws.ts                      |  23 +-
 10 files changed, 460 insertions(+), 73 deletions(-)
 create mode 100644 server/routers/newt/pingAccumulator.ts

diff --git a/server/cleanup.ts b/server/cleanup.ts
index 81cc31692..10e9f4cc3 100644
--- a/server/cleanup.ts
+++ b/server/cleanup.ts
@@ -1,9 +1,11 @@
 import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
 import { flushConnectionLogToDb } from "#dynamic/routers/newt";
 import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
+import { stopPingAccumulator } from "@server/routers/newt/pingAccumulator";
 import { cleanup as wsCleanup } from "#dynamic/routers/ws";

 async function cleanup() {
+    await stopPingAccumulator();
     await flushBandwidthToDb();
     await flushConnectionLogToDb();
     await flushSiteBandwidthToDb();
diff --git a/server/private/cleanup.ts b/server/private/cleanup.ts
index 4b12f1b3c..17d823491 100644
--- a/server/private/cleanup.ts
+++ b/server/private/cleanup.ts
@@ -16,8 +16,10 @@ import { cleanup as wsCleanup } from "#private/routers/ws";
 import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
 import { flushConnectionLogToDb } from "#dynamic/routers/newt";
 import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
+import { stopPingAccumulator } from "@server/routers/newt/pingAccumulator";

 async function cleanup() {
+    await stopPingAccumulator();
     await flushBandwidthToDb();
     await flushConnectionLogToDb();
     await flushSiteBandwidthToDb();
diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts
index 4bfda5da8..d96c55c91 100644
--- a/server/private/routers/ws/ws.ts
+++ b/server/private/routers/ws/ws.ts
@@ -30,6 +30,7 @@ import {
 } from "@server/db";
 import { eq } from "drizzle-orm";
 import { db } from "@server/db";
+import { recordPing } from "@server/routers/newt/pingAccumulator";
 import { validateNewtSessionToken } from "@server/auth/sessions/newt";
 import { validateOlmSessionToken } from "@server/auth/sessions/olm";
 import logger from "@server/logger";
@@ -197,11 +198,7 @@ const connectedClients: Map = new Map();

 // Config version tracking map (local to this node, resets on server restart)
 const clientConfigVersions: Map = new Map();

-// Tracks the last Unix timestamp (seconds) at which a ping was flushed to the
-// DB for a given siteId. Resets on server restart which is fine – the first
-// ping after startup will always write, re-establishing the online state.
-const lastPingDbWrite: Map<number, number> = new Map();
-const PING_DB_WRITE_INTERVAL = 45; // seconds
+

 // Recovery tracking
 let isRedisRecoveryInProgress = false;
@@ -853,32 +850,16 @@
         );
     });

-    // Handle WebSocket protocol-level pings from older newt clients that do
-    // not send application-level "newt/ping" messages. Update the site's
-    // online state and lastPing timestamp so the offline checker treats them
-    // the same as modern newt clients.
     if (clientType === "newt") {
         const newtClient = client as Newt;
-        ws.on("ping", async () => {
+        ws.on("ping", () => {
             if (!newtClient.siteId) return;
-            const now = Math.floor(Date.now() / 1000);
-            const lastWrite = lastPingDbWrite.get(newtClient.siteId) ?? 0;
-            if (now - lastWrite < PING_DB_WRITE_INTERVAL) return;
-            lastPingDbWrite.set(newtClient.siteId, now);
-            try {
-                await db
-                    .update(sites)
-                    .set({
-                        online: true,
-                        lastPing: now
-                    })
-                    .where(eq(sites.siteId, newtClient.siteId));
-            } catch (error) {
-                logger.error(
-                    "Error updating newt site online state on WS ping",
-                    { error }
-                );
-            }
+            // Record the ping in the accumulator instead of writing to the
+            // database on every WS ping frame. The accumulator flushes all
+            // pending pings in a single batched UPDATE every ~10s, which
+            // prevents connection pool exhaustion under load (especially
+            // with cross-region latency to the database).
+            recordPing(newtClient.siteId);
         });
     }

diff --git a/server/routers/newt/getNewtToken.ts b/server/routers/newt/getNewtToken.ts
index 637973582..bc3cca9fc 100644
--- a/server/routers/newt/getNewtToken.ts
+++ b/server/routers/newt/getNewtToken.ts
@@ -1,5 +1,5 @@
 import { generateSessionToken } from "@server/auth/sessions/app";
-import { db } from "@server/db";
+import { db, newtSessions } from "@server/db";
 import { newts } from "@server/db";
 import HttpCode from "@server/types/HttpCode";
 import response from "@server/lib/response";
@@ -92,6 +92,26 @@
         );
     }

+    const [existingSession] = await db
+        .select()
+        .from(newtSessions)
+        .where(eq(newtSessions.newtId, existingNewt.newtId));
+
+    // reuse the existing session only if it has at least 30 minutes left before expiry
+    if (existingSession && (existingSession.expiresAt - 30 * 60 * 1000) > Date.now()) {
+        return response<{ token: string; serverVersion: string }>(res, {
+            data: {
+                token: existingSession.sessionId,
+                serverVersion: APP_VERSION
+            },
+            success: true,
+            error: false,
+            message: "Token created successfully",
+            status: HttpCode.OK
+        });
+    }
+
+    // otherwise generate a new one
     const resToken = generateSessionToken();
     await createNewtSession(resToken, existingNewt.newtId);

diff --git a/server/routers/newt/handleNewtPingMessage.ts b/server/routers/newt/handleNewtPingMessage.ts
index 319647b83..da25852a0 100644
--- a/server/routers/newt/handleNewtPingMessage.ts
+++ b/server/routers/newt/handleNewtPingMessage.ts
@@ -5,6 +5,7 @@ import { Newt } from "@server/db";
 import { eq, lt, isNull, and, or } from "drizzle-orm";
 import logger from "@server/logger";
 import { sendNewtSyncMessage } from "./sync";
+import { recordPing } from "./pingAccumulator";

 // Track if the offline checker interval is running
 let offlineCheckerInterval: NodeJS.Timeout | null = null;
@@ -114,18 +115,12 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
         return;
     }

-    try {
-        // Mark the site as online and record the ping timestamp.
-        await db
-            .update(sites)
-            .set({
-                online: true,
-                lastPing: Math.floor(Date.now() / 1000)
-            })
-            .where(eq(sites.siteId, newt.siteId));
-    } catch (error) {
-        logger.error("Error updating online state on newt ping", { error });
-    }
+    // Record the ping in memory; it will be flushed to the database
+    // periodically by the ping accumulator (every ~10s) in a single
+    // batched UPDATE instead of one query per ping. This prevents
+    // connection pool exhaustion under load, especially with
+    // cross-region latency to the database.
+    recordPing(newt.siteId);

     // Check config version and sync if stale.
     const configVersion = await getClientConfigVersion(newt.newtId);
diff --git a/server/routers/newt/pingAccumulator.ts b/server/routers/newt/pingAccumulator.ts
new file mode 100644
index 000000000..83afd613e
--- /dev/null
+++ b/server/routers/newt/pingAccumulator.ts
@@ -0,0 +1,395 @@
+import { db } from "@server/db";
+import { sites, clients, olms } from "@server/db";
+import { eq, inArray } from "drizzle-orm";
+import logger from "@server/logger";
+
+/**
+ * Ping Accumulator
+ *
+ * Instead of writing to the database on every single newt/olm ping (which
+ * causes pool exhaustion under load, especially with cross-region latency),
+ * we accumulate pings in memory and flush them to the database periodically
+ * in a single batch.
+ *
+ * This is the same pattern used for bandwidth flushing in
+ * receiveBandwidth.ts and handleReceiveBandwidthMessage.ts.
+ *
+ * Supports two kinds of pings:
+ * - **Site pings** (from newts): update `sites.online` and `sites.lastPing`
+ * - **Client pings** (from OLMs): update `clients.online`, `clients.lastPing`,
+ *   `clients.archived`, and optionally reset `olms.archived`
+ */
+
+const FLUSH_INTERVAL_MS = 10_000; // Flush every 10 seconds
+const MAX_RETRIES = 2;
+const BASE_DELAY_MS = 50;
+
+// ── Site (newt) pings ──────────────────────────────────────────────────
+// Map of siteId -> latest ping timestamp (unix seconds)
+const pendingSitePings: Map<number, number> = new Map();
+
+// ── Client (OLM) pings ─────────────────────────────────────────────────
+// Map of clientId -> latest ping timestamp (unix seconds)
+const pendingClientPings: Map<number, number> = new Map();
+// Set of olmIds whose `archived` flag should be reset to false
+const pendingOlmArchiveResets: Set<string> = new Set();
+
+let flushTimer: NodeJS.Timeout | null = null;
+
+// ── Public API ─────────────────────────────────────────────────────────
+
+/**
+ * Record a ping for a newt site. This does NOT write to the database
+ * immediately. Instead it stores the latest ping timestamp in memory,
+ * to be flushed periodically by the background timer.
+ */
+export function recordSitePing(siteId: number): void {
+    const now = Math.floor(Date.now() / 1000);
+    pendingSitePings.set(siteId, now);
+}
+
+/** @deprecated Use `recordSitePing` instead. Alias kept for existing call-sites. */
+export const recordPing = recordSitePing;
+
+/**
+ * Record a ping for an OLM client. Batches the `clients` table update
+ * (`online`, `lastPing`, `archived`) and, when `olmArchived` is true,
+ * also queues an `olms` table update to clear the archived flag.
+ */
+export function recordClientPing(
+    clientId: number,
+    olmId: string,
+    olmArchived: boolean
+): void {
+    const now = Math.floor(Date.now() / 1000);
+    pendingClientPings.set(clientId, now);
+    if (olmArchived) {
+        pendingOlmArchiveResets.add(olmId);
+    }
+}
+
+// ── Flush Logic ────────────────────────────────────────────────────────
+
+/**
+ * Flush all accumulated site pings to the database.
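+ *
+ * Failed batches are re-queued into the pending map with last-write-wins
+ * semantics (a newer in-memory timestamp is never overwritten), so they
+ * are retried on the next flush cycle.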
+ */
+async function flushSitePingsToDb(): Promise<void> {
+    if (pendingSitePings.size === 0) {
+        return;
+    }
+
+    // Snapshot and clear so new pings arriving during the flush go into a
+    // fresh map for the next cycle.
+    const pingsToFlush = new Map(pendingSitePings);
+    pendingSitePings.clear();
+
+    // Sort by siteId for consistent lock ordering (prevents deadlocks)
+    const sortedEntries = Array.from(pingsToFlush.entries()).sort(
+        ([a], [b]) => a - b
+    );
+
+    const BATCH_SIZE = 50;
+    for (let i = 0; i < sortedEntries.length; i += BATCH_SIZE) {
+        const batch = sortedEntries.slice(i, i + BATCH_SIZE);
+
+        try {
+            await withRetry(async () => {
+                // Group by timestamp for efficient bulk updates
+                const byTimestamp = new Map<number, number[]>();
+                for (const [siteId, timestamp] of batch) {
+                    const group = byTimestamp.get(timestamp) || [];
+                    group.push(siteId);
+                    byTimestamp.set(timestamp, group);
+                }
+
+                if (byTimestamp.size === 1) {
+                    const [timestamp, siteIds] = Array.from(
+                        byTimestamp.entries()
+                    )[0];
+                    await db
+                        .update(sites)
+                        .set({
+                            online: true,
+                            lastPing: timestamp
+                        })
+                        .where(inArray(sites.siteId, siteIds));
+                } else {
+                    await db.transaction(async (tx) => {
+                        for (const [timestamp, siteIds] of byTimestamp) {
+                            await tx
+                                .update(sites)
+                                .set({
+                                    online: true,
+                                    lastPing: timestamp
+                                })
+                                .where(inArray(sites.siteId, siteIds));
+                        }
+                    });
+                }
+            }, "flushSitePingsToDb");
+        } catch (error) {
+            logger.error(
+                `Failed to flush site ping batch (${batch.length} sites), re-queuing for next cycle`,
+                { error }
+            );
+            for (const [siteId, timestamp] of batch) {
+                const existing = pendingSitePings.get(siteId);
+                if (!existing || existing < timestamp) {
+                    pendingSitePings.set(siteId, timestamp);
+                }
+            }
+        }
+    }
+}
+
+/**
+ * Flush all accumulated client (OLM) pings to the database.
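+ *
+ * Besides `clients.lastPing`/`online`/`archived`, this also clears the
+ * `archived` flag on any OLMs queued for reset via `recordClientPing`.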
+ */
+async function flushClientPingsToDb(): Promise<void> {
+    if (pendingClientPings.size === 0 && pendingOlmArchiveResets.size === 0) {
+        return;
+    }
+
+    // Snapshot and clear
+    const pingsToFlush = new Map(pendingClientPings);
+    pendingClientPings.clear();
+
+    const olmResetsToFlush = new Set(pendingOlmArchiveResets);
+    pendingOlmArchiveResets.clear();
+
+    // ── Flush client pings ─────────────────────────────────────────────
+    if (pingsToFlush.size > 0) {
+        const sortedEntries = Array.from(pingsToFlush.entries()).sort(
+            ([a], [b]) => a - b
+        );
+
+        const BATCH_SIZE = 50;
+        for (let i = 0; i < sortedEntries.length; i += BATCH_SIZE) {
+            const batch = sortedEntries.slice(i, i + BATCH_SIZE);
+
+            try {
+                await withRetry(async () => {
+                    const byTimestamp = new Map<number, number[]>();
+                    for (const [clientId, timestamp] of batch) {
+                        const group = byTimestamp.get(timestamp) || [];
+                        group.push(clientId);
+                        byTimestamp.set(timestamp, group);
+                    }
+
+                    if (byTimestamp.size === 1) {
+                        const [timestamp, clientIds] = Array.from(
+                            byTimestamp.entries()
+                        )[0];
+                        await db
+                            .update(clients)
+                            .set({
+                                lastPing: timestamp,
+                                online: true,
+                                archived: false
+                            })
+                            .where(inArray(clients.clientId, clientIds));
+                    } else {
+                        await db.transaction(async (tx) => {
+                            for (const [timestamp, clientIds] of byTimestamp) {
+                                await tx
+                                    .update(clients)
+                                    .set({
+                                        lastPing: timestamp,
+                                        online: true,
+                                        archived: false
+                                    })
+                                    .where(
+                                        inArray(clients.clientId, clientIds)
+                                    );
+                            }
+                        });
+                    }
+                }, "flushClientPingsToDb");
+            } catch (error) {
+                logger.error(
+                    `Failed to flush client ping batch (${batch.length} clients), re-queuing for next cycle`,
+                    { error }
+                );
+                for (const [clientId, timestamp] of batch) {
+                    const existing = pendingClientPings.get(clientId);
+                    if (!existing || existing < timestamp) {
+                        pendingClientPings.set(clientId, timestamp);
+                    }
+                }
+            }
+        }
+    }
+
+    // ── Flush OLM archive resets ───────────────────────────────────────
+    if (olmResetsToFlush.size > 0) {
+        const olmIds = Array.from(olmResetsToFlush).sort();
+
+        const BATCH_SIZE = 50;
+        for (let i = 0; i < olmIds.length; i += BATCH_SIZE) {
+            const batch = olmIds.slice(i, i + BATCH_SIZE);
+
+            try {
+                await withRetry(async () => {
+                    await db
+                        .update(olms)
+                        .set({ archived: false })
+                        .where(inArray(olms.olmId, batch));
+                }, "flushOlmArchiveResets");
+            } catch (error) {
+                logger.error(
+                    `Failed to flush OLM archive reset batch (${batch.length} olms), re-queuing for next cycle`,
+                    { error }
+                );
+                for (const olmId of batch) {
+                    pendingOlmArchiveResets.add(olmId);
+                }
+            }
+        }
+    }
+}
+
+/**
+ * Flush everything — called by the interval timer and during shutdown.
+ */
+export async function flushPingsToDb(): Promise<void> {
+    await flushSitePingsToDb();
+    await flushClientPingsToDb();
+}
+
+// ── Retry / Error Helpers ──────────────────────────────────────────────
+
+/**
+ * Simple retry wrapper with exponential backoff for transient errors
+ * (connection timeouts, unexpected disconnects).
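+ *
+ * Backoff is exponential with full jitter: attempt n waits
+ * 2^(n-1) * BASE_DELAY_MS plus a random extra of up to the same amount.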
+ */
+async function withRetry<T>(
+    operation: () => Promise<T>,
+    context: string
+): Promise<T> {
+    let attempt = 0;
+    while (true) {
+        try {
+            return await operation();
+        } catch (error: any) {
+            if (isTransientError(error) && attempt < MAX_RETRIES) {
+                attempt++;
+                const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS;
+                const jitter = Math.random() * baseDelay;
+                const delay = baseDelay + jitter;
+                logger.warn(
+                    `Transient DB error in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms`
+                );
+                await new Promise((resolve) => setTimeout(resolve, delay));
+                continue;
+            }
+            throw error;
+        }
+    }
+}
+
+/**
+ * Detect transient connection errors that are safe to retry.
+ */
+function isTransientError(error: any): boolean {
+    if (!error) return false;
+
+    const message = (error.message || "").toLowerCase();
+    const causeMessage = (error.cause?.message || "").toLowerCase();
+    const code = error.code || "";
+
+    // Connection timeout / terminated
+    if (
+        message.includes("connection timeout") ||
+        message.includes("connection terminated") ||
+        message.includes("timeout exceeded when trying to connect") ||
+        causeMessage.includes("connection terminated unexpectedly") ||
+        causeMessage.includes("connection timeout")
+    ) {
+        return true;
+    }
+
+    // PostgreSQL deadlock
+    if (code === "40P01" || message.includes("deadlock")) {
+        return true;
+    }
+
+    // ECONNRESET, ECONNREFUSED, EPIPE
+    if (
+        code === "ECONNRESET" ||
+        code === "ECONNREFUSED" ||
+        code === "EPIPE" ||
+        code === "ETIMEDOUT"
+    ) {
+        return true;
+    }
+
+    return false;
+}
+
+// ── Lifecycle ──────────────────────────────────────────────────────────
+
+/**
+ * Start the background flush timer. Call this once at server startup.
+ */
+export function startPingAccumulator(): void {
+    if (flushTimer) {
+        return; // Already running
+    }
+
+    flushTimer = setInterval(async () => {
+        try {
+            await flushPingsToDb();
+        } catch (error) {
+            logger.error("Unhandled error in ping accumulator flush", {
+                error
+            });
+        }
+    }, FLUSH_INTERVAL_MS);
+
+    // Don't prevent the process from exiting
+    flushTimer.unref();
+
+    logger.info(
+        `Ping accumulator started (flush interval: ${FLUSH_INTERVAL_MS}ms)`
+    );
+}
+
+/**
+ * Stop the background flush timer and perform a final flush.
+ * Call this during graceful shutdown.
+ */
+export async function stopPingAccumulator(): Promise<void> {
+    if (flushTimer) {
+        clearInterval(flushTimer);
+        flushTimer = null;
+    }
+
+    // Final flush to persist any remaining pings
+    try {
+        await flushPingsToDb();
+    } catch (error) {
+        logger.error("Error during final ping accumulator flush", { error });
+    }
+
+    logger.info("Ping accumulator stopped");
+}
+
+/**
+ * Get the number of pending (unflushed) pings. Useful for monitoring.
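+ *
+ * Counts pending site and client pings only; queued OLM archive resets
+ * are not included.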
+ */
+export function getPendingPingCount(): number {
+    return pendingSitePings.size + pendingClientPings.size;
+}
\ No newline at end of file
diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts
index 2734a63bc..027e7ec15 100644
--- a/server/routers/olm/getOlmToken.ts
+++ b/server/routers/olm/getOlmToken.ts
@@ -8,7 +8,8 @@ import {
     ExitNode,
     exitNodes,
     sites,
-    clientSitesAssociationsCache
+    clientSitesAssociationsCache,
+    olmSessions
 } from "@server/db";
 import { olms } from "@server/db";
 import HttpCode from "@server/types/HttpCode";
diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts
index efcbf1696..0f520b234 100644
--- a/server/routers/olm/handleOlmPingMessage.ts
+++ b/server/routers/olm/handleOlmPingMessage.ts
@@ -3,6 +3,7 @@ import { db } from "@server/db";
 import { MessageHandler } from "@server/routers/ws";
 import { clients, olms, Olm } from "@server/db";
 import { eq, lt, isNull, and, or } from "drizzle-orm";
+import { recordClientPing } from "@server/routers/newt/pingAccumulator";
 import logger from "@server/logger";
 import { validateSessionToken } from "@server/auth/sessions/app";
 import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
@@ -201,22 +202,12 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
             await sendOlmSyncMessage(olm, client);
         }

-        // Update the client's last ping timestamp
-        await db
-            .update(clients)
-            .set({
-                lastPing: Math.floor(Date.now() / 1000),
-                online: true,
-                archived: false
-            })
-            .where(eq(clients.clientId, olm.clientId));
-
-        if (olm.archived) {
-            await db
-                .update(olms)
-                .set({ archived: false })
-                .where(eq(olms.olmId, olm.olmId));
-        }
+        // Record the ping in memory; it will be flushed to the database
+        // periodically by the ping accumulator (every ~10s) in a single
+        // batched UPDATE instead of one query per ping. This prevents
+        // connection pool exhaustion under load, especially with
+        // cross-region latency to the database.
+        recordClientPing(olm.clientId, olm.olmId, !!olm.archived);
     } catch (error) {
         logger.error("Error handling ping message", { error });
     }
diff --git a/server/routers/ws/messageHandlers.ts b/server/routers/ws/messageHandlers.ts
index 628caafd5..143e4d516 100644
--- a/server/routers/ws/messageHandlers.ts
+++ b/server/routers/ws/messageHandlers.ts
@@ -11,6 +11,7 @@ import {
     startNewtOfflineChecker,
     handleNewtDisconnectingMessage
 } from "../newt";
+import { startPingAccumulator } from "../newt/pingAccumulator";
 import {
     handleOlmRegisterMessage,
     handleOlmRelayMessage,
@@ -46,6 +47,12 @@ export const messageHandlers: Record<string, MessageHandler> = {
     "ws/round-trip/complete": handleRoundTripMessage
 };

+// Start the ping accumulator for all builds — it batches per-site online/lastPing
+// updates into periodic bulk writes, preventing connection pool exhaustion.
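+// startPingAccumulator() is idempotent (it returns early when the flush
+// timer is already running), so evaluating this module twice is harmless.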
+startPingAccumulator();
+
 if (build != "saas") {
     startOlmOfflineChecker(); // this is to handle the offline check for olms
     startNewtOfflineChecker(); // this is to handle the offline check for newts
diff --git a/server/routers/ws/ws.ts b/server/routers/ws/ws.ts
index 08a7dbd4c..6e6312715 100644
--- a/server/routers/ws/ws.ts
+++ b/server/routers/ws/ws.ts
@@ -6,6 +6,7 @@ import { Socket } from "net";
 import { Newt, newts, NewtSession, olms, Olm, OlmSession, sites } from "@server/db";
 import { eq } from "drizzle-orm";
 import { db } from "@server/db";
+import { recordPing } from "@server/routers/newt/pingAccumulator";
 import { validateNewtSessionToken } from "@server/auth/sessions/newt";
 import { validateOlmSessionToken } from "@server/auth/sessions/olm";
 import { messageHandlers } from "./messageHandlers";
@@ -386,22 +387,14 @@
     // the same as modern newt clients.
     if (clientType === "newt") {
         const newtClient = client as Newt;
-        ws.on("ping", async () => {
+        ws.on("ping", () => {
             if (!newtClient.siteId) return;
-            try {
-                await db
-                    .update(sites)
-                    .set({
-                        online: true,
-                        lastPing: Math.floor(Date.now() / 1000)
-                    })
-                    .where(eq(sites.siteId, newtClient.siteId));
-            } catch (error) {
-                logger.error(
-                    "Error updating newt site online state on WS ping",
-                    { error }
-                );
-            }
+            // Record the ping in the accumulator instead of writing to the
+            // database on every WS ping frame. The accumulator flushes all
+            // pending pings in a single batched UPDATE every ~10s, which
+            // prevents connection pool exhaustion under load (especially
+            // with cross-region latency to the database).
+            recordPing(newtClient.siteId);
         });
     }


From 6bb6cf8a48dd99f291657b02ed6b99fba4aab144 Mon Sep 17 00:00:00 2001
From: Owen
Date: Tue, 24 Mar 2026 17:54:51 -0700
Subject: [PATCH 18/24] Clean up

---
 server/private/routers/ws/ws.ts     |  5 -----
 server/routers/newt/getNewtToken.ts | 20 --------------------
 server/routers/olm/getOlmToken.ts   |  1 -
 3 files changed, 26 deletions(-)

diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts
index d96c55c91..67b83f931 100644
--- a/server/private/routers/ws/ws.ts
+++ b/server/private/routers/ws/ws.ts
@@ -19,14 +19,9 @@ import { Socket } from "net";
 import {
     Newt,
     newts,
-    NewtSession,
-    olms,
     Olm,
-    OlmSession,
     RemoteExitNode,
-    RemoteExitNodeSession,
     remoteExitNodes,
-    sites
 } from "@server/db";
 import { eq } from "drizzle-orm";
 import { db } from "@server/db";
diff --git a/server/routers/newt/getNewtToken.ts b/server/routers/newt/getNewtToken.ts
index bc3cca9fc..9d3da7e97 100644
--- a/server/routers/newt/getNewtToken.ts
+++ b/server/routers/newt/getNewtToken.ts
@@ -92,26 +92,6 @@ export async function getNewtToken(
         );
     }

-    const [existingSession] = await db
-        .select()
-        .from(newtSessions)
-        .where(eq(newtSessions.newtId, existingNewt.newtId));
-
-    // reuse the existing session only if it has at least 30 minutes left before expiry
-    if (existingSession && (existingSession.expiresAt - 30 * 60 * 1000) > Date.now()) {
-        return response<{ token: string; serverVersion: string }>(res, {
-            data: {
-                token: existingSession.sessionId,
-                serverVersion: APP_VERSION
-            },
-            success: true,
-            error: false,
-            message: "Token created successfully",
-            status: HttpCode.OK
-        });
-    }
-
-    // otherwise generate a new one
     const resToken = generateSessionToken();
     await createNewtSession(resToken, existingNewt.newtId);

diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts
index 027e7ec15..741b29f0a 100644
--- a/server/routers/olm/getOlmToken.ts
+++ b/server/routers/olm/getOlmToken.ts
@@ -9,7 +9,6 @@ import {
     exitNodes,
     sites,
     clientSitesAssociationsCache,
-    olmSessions
 } from "@server/db";
 import { olms } from "@server/db";
 import HttpCode from "@server/types/HttpCode";

From 9b84623d0c40a89b2aa1f9e8a7242b35e445d558 Mon Sep 17 00:00:00 2001
From: Owen
Date: Tue, 24 Mar 2026 18:12:51 -0700
Subject: [PATCH 19/24] Cache token for thundering herd

---
 server/lib/tokenCache.ts                             | 22 ++++++
 server/private/lib/cache.ts                          | 13 ++++
 server/private/lib/tokenCache.ts                     | 77 +++++++++++++++++++
 .../remoteExitNode/getRemoteExitNodeToken.ts         | 25 ++++--
 server/routers/newt/getNewtToken.ts                  | 18 ++++-
 server/routers/olm/getOlmToken.ts                    | 19 ++++-
 6 files changed, 161 insertions(+), 13 deletions(-)
 create mode 100644 server/lib/tokenCache.ts
 create mode 100644 server/private/lib/tokenCache.ts

diff --git a/server/lib/tokenCache.ts b/server/lib/tokenCache.ts
new file mode 100644
index 000000000..022f46c15
--- /dev/null
+++ b/server/lib/tokenCache.ts
@@ -0,0 +1,22 @@
+/**
+ * Open-source stub of the token cache: it always mints a fresh token via
+ * `createSession` and performs no caching at all. The extra parameters are
+ * accepted only so the signature matches the Redis-backed implementation
+ * that `#dynamic/lib/tokenCache` resolves to in private builds, where the
+ * token is encrypted with `secret` and cached under `cacheKey` for
+ * `ttlSeconds` to absorb thundering-herd reconnects.
+ *
+ * @param cacheKey Unique cache key, e.g. `"newt:token_cache:abc123"` (unused here)
+ * @param secret Server secret used for AES encryption/decryption (unused here)
+ * @param ttlSeconds Cache TTL in seconds (unused here)
+ * @param createSession Factory that mints a new session and returns its raw token
+ */
+export async function getOrCreateCachedToken(
+    cacheKey: string,
+    secret: string,
+    ttlSeconds: number,
+    createSession: () => Promise<string>
+): Promise<string> {
+    const token = await createSession();
+    return token;
+}
diff --git a/server/private/lib/cache.ts b/server/private/lib/cache.ts
index e8c03ba3d..1a2006d46 100644
--- a/server/private/lib/cache.ts
+++ b/server/private/lib/cache.ts
@@ -1,3 +1,16 @@
+/*
+ * This file is part of a proprietary work.
+ *
+ * Copyright (c) 2025 Fossorial, Inc.
+ * All rights reserved.
+ *
+ * This file is licensed under the Fossorial Commercial License.
+ * You may not use this file except in compliance with the License.
+ * Unauthorized use, copying, modification, or distribution is strictly prohibited.
+ *
+ * This file is not licensed under the AGPLv3.
+ */
+
 import NodeCache from "node-cache";
 import logger from "@server/logger";
 import { redisManager } from "@server/private/lib/redis";
diff --git a/server/private/lib/tokenCache.ts b/server/private/lib/tokenCache.ts
new file mode 100644
index 000000000..bb6645688
--- /dev/null
+++ b/server/private/lib/tokenCache.ts
@@ -0,0 +1,77 @@
+/*
+ * This file is part of a proprietary work.
+ *
+ * Copyright (c) 2025 Fossorial, Inc.
+ * All rights reserved.
+ *
+ * This file is licensed under the Fossorial Commercial License.
+ * You may not use this file except in compliance with the License.
+ * Unauthorized use, copying, modification, or distribution is strictly prohibited.
+ *
+ * This file is not licensed under the AGPLv3.
+ */
+
+import redisManager from "#dynamic/lib/redis";
+import { encrypt, decrypt } from "@server/lib/crypto";
+import logger from "@server/logger";
+
+/**
+ * Returns a cached plaintext token from Redis if one exists and decrypts
+ * cleanly, otherwise calls `createSession` to mint a fresh token, stores the
+ * encrypted value in Redis with the given TTL, and returns it.
+ *
+ * Failures at the Redis layer are non-fatal – the function always falls
+ * through to session creation so the caller is never blocked by a Redis outage.
+ *
+ * @param cacheKey Unique Redis key, e.g. `"newt:token_cache:abc123"`
+ * @param secret Server secret used for AES encryption/decryption
+ * @param ttlSeconds Cache TTL in seconds (should match session expiry)
+ * @param createSession Factory that mints a new session and returns its raw token
+ */
+export async function getOrCreateCachedToken(
+    cacheKey: string,
+    secret: string,
+    ttlSeconds: number,
+    createSession: () => Promise<string>
+): Promise<string> {
+    if (redisManager.isRedisEnabled()) {
+        try {
+            const cached = await redisManager.get(cacheKey);
+            if (cached) {
+                const token = decrypt(cached, secret);
+                if (token) {
+                    logger.debug(`Token cache hit for key: ${cacheKey}`);
+                    return token;
+                }
+                // Decryption produced an empty string – treat as a miss
+                logger.warn(
+                    `Token cache decryption returned empty string for key: ${cacheKey}, treating as miss`
+                );
+            }
+        } catch (e) {
+            logger.warn(
+                `Token cache read/decrypt failed for key ${cacheKey}, falling through to session creation:`,
+                e
+            );
+        }
+    }
+
+    const token = await createSession();
+
+    if (redisManager.isRedisEnabled()) {
+        try {
+            const encrypted = encrypt(token, secret);
+            await redisManager.set(cacheKey, encrypted, ttlSeconds);
+            logger.debug(
+                `Token cached in Redis for key: ${cacheKey} (TTL ${ttlSeconds}s)`
+            );
+        } catch (e) {
+            logger.warn(
+                `Token cache write failed for key ${cacheKey} (session was still created):`,
+                e
+            );
+        }
+    }
+
+    return token;
+}
diff --git a/server/private/routers/remoteExitNode/getRemoteExitNodeToken.ts b/server/private/routers/remoteExitNode/getRemoteExitNodeToken.ts
index 24f0de159..025e2d34e 100644
--- a/server/private/routers/remoteExitNode/getRemoteExitNodeToken.ts
+++ b/server/private/routers/remoteExitNode/getRemoteExitNodeToken.ts
@@ -23,8 +23,10 @@ import { z } from "zod";
 import { fromError } from "zod-validation-error";
 import {
     createRemoteExitNodeSession,
-    validateRemoteExitNodeSessionToken
+    validateRemoteExitNodeSessionToken,
+    EXPIRES
 } from "#private/auth/sessions/remoteExitNode";
+import { getOrCreateCachedToken } from "@server/private/lib/tokenCache";
 import { verifyPassword } from "@server/auth/password";
 import logger from "@server/logger";
 import config from "@server/lib/config";
@@ -103,14 +105,23 @@ export async function getRemoteExitNodeToken(
         );
     }

-    const resToken = generateSessionToken();
-    await createRemoteExitNodeSession(
-        resToken,
-        existingRemoteExitNode.remoteExitNodeId
+    // Return a cached token if one exists to prevent thundering herd on
+    // simultaneous restarts; falls back to creating a fresh session when
+    // Redis is unavailable or the cache has expired.
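+    // The cache TTL mirrors the session lifetime (EXPIRES), so a cached
+    // token expires from the cache at about the same time its session does.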
+ const resToken = await getOrCreateCachedToken( + `remote_exit_node:token_cache:${existingRemoteExitNode.remoteExitNodeId}`, + config.getRawConfig().server.secret!, + Math.floor(EXPIRES / 1000), + async () => { + const token = generateSessionToken(); + await createRemoteExitNodeSession( + token, + existingRemoteExitNode.remoteExitNodeId + ); + return token; + } ); - // logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`); - return response<{ token: string }>(res, { data: { token: resToken diff --git a/server/routers/newt/getNewtToken.ts b/server/routers/newt/getNewtToken.ts index 9d3da7e97..c5abb9968 100644 --- a/server/routers/newt/getNewtToken.ts +++ b/server/routers/newt/getNewtToken.ts @@ -1,6 +1,8 @@ import { generateSessionToken } from "@server/auth/sessions/app"; import { db, newtSessions } from "@server/db"; import { newts } from "@server/db"; +import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache"; +import { EXPIRES } from "@server/auth/sessions/newt"; import HttpCode from "@server/types/HttpCode"; import response from "@server/lib/response"; import { eq } from "drizzle-orm"; @@ -92,9 +94,19 @@ export async function getNewtToken( ); } - // otherwise generate a new one - const resToken = generateSessionToken(); - await createNewtSession(resToken, existingNewt.newtId); + // Return a cached token if one exists to prevent thundering herd on + // simultaneous restarts; falls back to creating a fresh session when + // Redis is unavailable or the cache has expired. + const resToken = await getOrCreateCachedToken( + `newt:token_cache:${existingNewt.newtId}`, + config.getRawConfig().server.secret!, + Math.floor(EXPIRES / 1000), + async () => { + const token = generateSessionToken(); + await createNewtSession(token, existingNewt.newtId); + return token; + } + ); return response<{ token: string; serverVersion: string }>(res, { data: { diff --git a/server/routers/olm/getOlmToken.ts b/server/routers/olm/getOlmToken.ts index 741b29f0a..5b8411eb7 100644 --- a/server/routers/olm/getOlmToken.ts +++ b/server/routers/olm/getOlmToken.ts @@ -20,8 +20,10 @@ import { z } from "zod"; import { fromError } from "zod-validation-error"; import { createOlmSession, - validateOlmSessionToken + validateOlmSessionToken, + EXPIRES } from "@server/auth/sessions/olm"; +import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache"; import { verifyPassword } from "@server/auth/password"; import logger from "@server/logger"; import config from "@server/lib/config"; @@ -132,8 +134,19 @@ export async function getOlmToken( logger.debug("Creating new olm session token"); - const resToken = generateSessionToken(); - await createOlmSession(resToken, existingOlm.olmId); + // Return a cached token if one exists to prevent thundering herd on + // simultaneous restarts; falls back to creating a fresh session when + // Redis is unavailable or the cache has expired. 
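+    // The cache key is namespaced per olmId, so concurrent olms never
+    // read or overwrite each other's tokens.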
+    const resToken = await getOrCreateCachedToken(
+        `olm:token_cache:${existingOlm.olmId}`,
+        config.getRawConfig().server.secret!,
+        Math.floor(EXPIRES / 1000),
+        async () => {
+            const token = generateSessionToken();
+            await createOlmSession(token, existingOlm.olmId);
+            return token;
+        }
+    );

     let clientIdToUse;
     if (orgId) {

From 99a064b77af0d84f8d996885bb4ee5db5dfb5305 Mon Sep 17 00:00:00 2001
From: Owen
Date: Tue, 24 Mar 2026 20:27:34 -0700
Subject: [PATCH 20/24] Fix import problems

---
 server/private/lib/tokenCache.ts | 2 +-
 server/private/routers/ws/ws.ts  | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/private/lib/tokenCache.ts b/server/private/lib/tokenCache.ts
index bb6645688..284f1d698 100644
--- a/server/private/lib/tokenCache.ts
+++ b/server/private/lib/tokenCache.ts
@@ -11,7 +11,7 @@
  * This file is not licensed under the AGPLv3.
  */

-import redisManager from "#dynamic/lib/redis";
+import redisManager from "#private/lib/redis";
 import { encrypt, decrypt } from "@server/lib/crypto";
 import logger from "@server/logger";

diff --git a/server/private/routers/ws/ws.ts b/server/private/routers/ws/ws.ts
index 67b83f931..21f4fad37 100644
--- a/server/private/routers/ws/ws.ts
+++ b/server/private/routers/ws/ws.ts
@@ -20,6 +20,7 @@ import {
     Newt,
     newts,
     Olm,
+    olms,
     RemoteExitNode,
     remoteExitNodes,
 } from "@server/db";

From c80c7df1d06266fbea9926159b86fd7bb675b178 Mon Sep 17 00:00:00 2001
From: Owen
Date: Wed, 25 Mar 2026 20:33:58 -0700
Subject: [PATCH 21/24] Batch set bandwidth

---
 server/routers/gerbil/receiveBandwidth.ts | 148 +++++++++++++---------
 1 file changed, 86 insertions(+), 62 deletions(-)

diff --git a/server/routers/gerbil/receiveBandwidth.ts b/server/routers/gerbil/receiveBandwidth.ts
index b73ce986d..042c844aa 100644
--- a/server/routers/gerbil/receiveBandwidth.ts
+++ b/server/routers/gerbil/receiveBandwidth.ts
@@ -1,6 +1,5 @@
 import { Request, Response, NextFunction } from "express";
-import { eq, sql } from "drizzle-orm";
-import { sites } from "@server/db";
+import { sql } from "drizzle-orm";
 import { db } from "@server/db";
 import logger from "@server/logger";
 import createHttpError from "http-errors";
@@ -31,7 +30,10 @@ const MAX_RETRIES = 3;
 const BASE_DELAY_MS = 50;

 // How often to flush accumulated bandwidth data to the database
-const FLUSH_INTERVAL_MS = 30_000; // 30 seconds
+const FLUSH_INTERVAL_MS = 300_000; // 5 minutes
+
+// Maximum number of sites to include in a single batch UPDATE statement
+const BATCH_CHUNK_SIZE = 250;

 // In-memory accumulator: publicKey -> AccumulatorEntry
 let accumulator = new Map<string, AccumulatorEntry>();
@@ -75,13 +77,33 @@ async function withDeadlockRetry(
     }
 }

+/**
+ * Execute a raw SQL query that returns rows, in a way that works across both
+ * the PostgreSQL driver (which exposes `execute`) and the SQLite driver (which
+ * exposes `all`). Drizzle's typed query builder doesn't support bulk
+ * UPDATE … FROM (VALUES …) natively, so we drop to raw SQL here.
+ */
+async function dbQueryRows<T extends Record<string, unknown>>(
+    query: Parameters<(typeof sql)["join"]>[0][number]
+): Promise<T[]> {
+    const anyDb = db as any;
+    if (typeof anyDb.execute === "function") {
+        // PostgreSQL (node-postgres via Drizzle) — returns { rows: [...] } or an array
+        const result = await anyDb.execute(query);
+        return (Array.isArray(result) ? result : (result.rows ?? [])) as T[];
+    }
+    // SQLite (better-sqlite3 via Drizzle) — returns an array directly
+    return (await anyDb.all(query)) as T[];
+}
+
 /**
  * Flush all accumulated site bandwidth data to the database.
 *
 * Swaps out the accumulator before writing so that any bandwidth messages
 * received during the flush are captured in the new accumulator rather than
- * being lost or causing contention. Entries that fail to write are re-queued
- * back into the accumulator so they will be retried on the next flush.
+ * being lost or causing contention. Sites are updated in chunks via a single
+ * batch UPDATE per chunk. Failed chunks are discarded — exact per-flush
+ * accuracy is not critical and re-queuing is not worth the added complexity.
 *
 * This function is exported so that the application's graceful-shutdown
 * cleanup handler can call it before the process exits.
 */
@@ -108,76 +130,78 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
         `Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
     );

-    // Aggregate billing usage by org, collected during the DB update loop.
+    // Build a lookup so post-processing can reach each entry by publicKey.
+    const snapshotMap = new Map(sortedEntries);
+
+    // Aggregate billing usage by org across all chunks.
     const orgUsageMap = new Map<string, number>();

-    for (const [publicKey, { bytesIn, bytesOut, exitNodeId, calcUsage }] of sortedEntries) {
+    // Process in chunks so individual queries stay at a reasonable size.
+    for (let i = 0; i < sortedEntries.length; i += BATCH_CHUNK_SIZE) {
+        const chunk = sortedEntries.slice(i, i + BATCH_CHUNK_SIZE);
+        const chunkEnd = i + chunk.length - 1;
+
+        // Build a parameterised VALUES list: (pubKey, bytesIn, bytesOut), ...
+        // Both PostgreSQL and SQLite (≥ 3.33.0, which better-sqlite3 bundles)
+        // support UPDATE … FROM (VALUES …), letting us update the whole chunk
+        // in a single query instead of N individual round-trips.
+        const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) =>
+            sql`(${publicKey}, ${bytesIn}, ${bytesOut})`
+        );
+        const valuesClause = sql.join(valuesList, sql`, `);
+
+        let rows: { orgId: string; pubKey: string }[] = [];
+
         try {
-            const updatedSite = await withDeadlockRetry(async () => {
-                const [result] = await db
-                    .update(sites)
-                    .set({
-                        megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
-                        megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
-                        lastBandwidthUpdate: currentTime,
-                    })
-                    .where(eq(sites.pubKey, publicKey))
-                    .returning({
-                        orgId: sites.orgId,
-                        siteId: sites.siteId
-                    });
-                return result;
-            }, `flush bandwidth for site ${publicKey}`);
-
-            if (updatedSite) {
-                if (exitNodeId) {
-                    const notAllowed = await checkExitNodeOrg(
-                        exitNodeId,
-                        updatedSite.orgId
-                    );
-                    if (notAllowed) {
-                        logger.warn(
-                            `Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
-                        );
-                        // Skip usage tracking for this site but continue
-                        // processing the rest.
-                        continue;
-                    }
-                }
-
-                if (calcUsage) {
-                    const totalBandwidth = bytesIn + bytesOut;
-                    const current = orgUsageMap.get(updatedSite.orgId) ?? 0;
-                    orgUsageMap.set(updatedSite.orgId, current + totalBandwidth);
-                }
-            }
+            rows = await withDeadlockRetry(async () => {
+                return dbQueryRows<{ orgId: string; pubKey: string }>(sql`
+                    UPDATE sites
+                    SET
+                        "megabytesOut" = COALESCE("megabytesOut", 0) + v.bytes_in,
+                        "megabytesIn" = COALESCE("megabytesIn", 0) + v.bytes_out,
+                        "lastBandwidthUpdate" = ${currentTime}
+                    FROM (VALUES ${valuesClause}) AS v(pub_key, bytes_in, bytes_out)
+                    WHERE sites."pubKey" = v.pub_key
+                    RETURNING sites."orgId" AS "orgId", sites."pubKey" AS "pubKey"
+                `);
+            }, `flush bandwidth chunk [${i}–${chunkEnd}]`);
         } catch (error) {
             logger.error(
-                `Failed to flush bandwidth for site ${publicKey}:`,
+                `Failed to flush bandwidth chunk [${i}–${chunkEnd}], discarding ${chunk.length} site(s):`,
                 error
             );
+            // Discard the chunk — exact per-flush accuracy is not critical.
+            continue;
+        }

-            // Re-queue the failed entry so it is retried on the next flush
-            // rather than silently dropped.
-            const existing = accumulator.get(publicKey);
-            if (existing) {
-                existing.bytesIn += bytesIn;
-                existing.bytesOut += bytesOut;
-            } else {
-                accumulator.set(publicKey, {
-                    bytesIn,
-                    bytesOut,
-                    exitNodeId,
-                    calcUsage
-                });
-            }
+        // Collect billing usage from the returned rows.
+        for (const { orgId, pubKey } of rows) {
+            const entry = snapshotMap.get(pubKey);
+            if (!entry) continue;
+
+            const { bytesIn, bytesOut, exitNodeId, calcUsage } = entry;
+
+            if (exitNodeId) {
+                const notAllowed = await checkExitNodeOrg(exitNodeId, orgId);
+                if (notAllowed) {
+                    logger.warn(
+                        `Exit node ${exitNodeId} is not allowed for org ${orgId}`
+                    );
+                    continue;
+                }
+            }
+
+            if (calcUsage) {
+                const current = orgUsageMap.get(orgId) ?? 0;
+                orgUsageMap.set(orgId, current + bytesIn + bytesOut);
+            }
         }
     }

-    // Process billing usage updates outside the site-update loop to keep
-    // lock scope small and concerns separated.
+    // Process billing usage updates after all chunks are written.
     if (orgUsageMap.size > 0) {
-        // Sort org IDs for consistent lock ordering.
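+        // Sort org IDs so concurrent flushers acquire row locks in the same
+        // order; consistent lock ordering prevents deadlocks between nodes.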
         const sortedOrgIds = [...orgUsageMap.keys()].sort();

         for (const orgId of sortedOrgIds) {

From b4ca6432dba2ea012c962da5ce0ffb03d8b20f2d Mon Sep 17 00:00:00 2001
From: Owen
Date: Thu, 26 Mar 2026 16:24:44 -0700
Subject: [PATCH 22/24] Fix double id

---
 src/components/CreateSiteProvisioningKeyCredenza.tsx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/components/CreateSiteProvisioningKeyCredenza.tsx b/src/components/CreateSiteProvisioningKeyCredenza.tsx
index 70c48ff08..e18bbbd06 100644
--- a/src/components/CreateSiteProvisioningKeyCredenza.tsx
+++ b/src/components/CreateSiteProvisioningKeyCredenza.tsx
@@ -150,7 +150,7 @@ export default function CreateSiteProvisioningKeyCredenza({

     const credential =
         created &&
-        `${created.siteProvisioningKeyId}.${created.siteProvisioningKey}`;
+        created.siteProvisioningKey;

     const unlimitedBatchSize = form.watch("unlimitedBatchSize");


From 19a686b3e4040f91d125184b209517992a989dc8 Mon Sep 17 00:00:00 2001
From: Owen
Date: Thu, 26 Mar 2026 16:49:43 -0700
Subject: [PATCH 23/24] Add restrictions around provisioning key

---
 server/routers/newt/registerNewt.ts | 43 ++++++++++++++++++--------
 1 file changed, 31 insertions(+), 12 deletions(-)

diff --git a/server/routers/newt/registerNewt.ts b/server/routers/newt/registerNewt.ts
index b999eb35c..427ac173f 100644
--- a/server/routers/newt/registerNewt.ts
+++ b/server/routers/newt/registerNewt.ts
@@ -14,7 +14,7 @@ import response from "@server/lib/response";
 import HttpCode from "@server/types/HttpCode";
 import createHttpError from "http-errors";
 import logger from "@server/logger";
-import { eq, and } from "drizzle-orm";
+import { eq, and, sql } from "drizzle-orm";
 import { fromError } from "zod-validation-error";
 import { verifyPassword, hashPassword } from "@server/auth/password";
 import {
@@ -77,7 +77,10 @@ export async function registerNewt(
                 siteProvisioningKeys.siteProvisioningKeyId,
             siteProvisioningKeyHash:
                 siteProvisioningKeys.siteProvisioningKeyHash,
-            orgId: siteProvisioningKeyOrg.orgId
+            orgId: siteProvisioningKeyOrg.orgId,
+            maxBatchSize: siteProvisioningKeys.maxBatchSize,
+            numUsed: siteProvisioningKeys.numUsed,
+            validUntil: siteProvisioningKeys.validUntil
         })
         .from(siteProvisioningKeys)
         .innerJoin(
@@ -118,13 +121,28 @@ export async function registerNewt(
         );
     }

+    if (keyRecord.maxBatchSize && keyRecord.numUsed >= keyRecord.maxBatchSize) {
+        return next(
+            createHttpError(
+                HttpCode.UNAUTHORIZED,
+                "Provisioning key has reached its maximum usage"
+            )
+        );
+    }
+
+    if (keyRecord.validUntil && new Date(keyRecord.validUntil) < new Date()) {
+        return next(
+            createHttpError(
+                HttpCode.UNAUTHORIZED,
+                "Provisioning key has expired"
+            )
+        );
+    }
+
     const { orgId } = keyRecord;

     // Verify the org exists
-    const [org] = await db
-        .select()
-        .from(orgs)
-        .where(eq(orgs.orgId, orgId));
+    const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
     if (!org) {
         return next(
             createHttpError(HttpCode.NOT_FOUND, "Organization not found")
@@ -208,7 +226,11 @@ export async function registerNewt(

-        // Consume the provisioning key — cascade removes siteProvisioningKeyOrg
+        // Record usage of the provisioning key instead of consuming it
         await trx
-            .delete(siteProvisioningKeys)
+
.update(siteProvisioningKeys) + .set({ + lastUsed: moment().toISOString(), + numUsed: sql`${siteProvisioningKeys.numUsed} + 1` + }) .where( eq( siteProvisioningKeys.siteProvisioningKeyId, @@ -236,10 +260,7 @@ export async function registerNewt( } catch (error) { logger.error(error); return next( - createHttpError( - HttpCode.INTERNAL_SERVER_ERROR, - "An error occurred" - ) + createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred") ); } -} \ No newline at end of file +} From e05af54f76492ad62b31ad75494fa444a6bf4fa7 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 26 Mar 2026 21:36:51 -0700 Subject: [PATCH 24/24] Use standard component --- .../CreateSiteProvisioningKeyCredenza.tsx | 111 ++++++++++++++---- .../EditSiteProvisioningKeyCredenza.tsx | 93 +++++++++++---- 2 files changed, 162 insertions(+), 42 deletions(-) diff --git a/src/components/CreateSiteProvisioningKeyCredenza.tsx b/src/components/CreateSiteProvisioningKeyCredenza.tsx index e18bbbd06..3a1c7c372 100644 --- a/src/components/CreateSiteProvisioningKeyCredenza.tsx +++ b/src/components/CreateSiteProvisioningKeyCredenza.tsx @@ -36,6 +36,10 @@ import { useForm } from "react-hook-form"; import { z } from "zod"; import { zodResolver } from "@hookform/resolvers/zod"; import CopyTextBox from "@app/components/CopyTextBox"; +import { + DateTimePicker, + DateTimeValue +} from "@app/components/DateTimePicker"; const FORM_ID = "create-site-provisioning-key-form"; @@ -261,27 +265,92 @@ export default function CreateSiteProvisioningKeyCredenza({ ( - - - {t( - "provisioningKeysValidUntil" - )} - - - - - - {t( - "provisioningKeysValidUntilHint" - )} - - - - )} + render={({ field }) => { + const dateTimeValue: DateTimeValue = + (() => { + if (!field.value) return {}; + const d = new Date( + field.value + ); + if (isNaN(d.getTime())) + return {}; + const hours = d + .getHours() + .toString() + .padStart(2, "0"); + const minutes = d + .getMinutes() + .toString() + .padStart(2, "0"); + const seconds = d + .getSeconds() + .toString() + .padStart(2, "0"); + return { + date: d, + time: `${hours}:${minutes}:${seconds}` + }; + })(); + + return ( + + + {t( + "provisioningKeysValidUntil" + )} + + + { + if (!value.date) { + field.onChange( + "" + ); + return; + } + const d = new Date( + value.date + ); + if (value.time) { + const [ + h, + m, + s + ] = + value.time.split( + ":" + ); + d.setHours( + parseInt( + h, + 10 + ), + parseInt( + m, + 10 + ), + parseInt( + s || "0", + 10 + ) + ); + } + field.onChange( + d.toISOString() + ); + }} + /> + + + {t( + "provisioningKeysValidUntilHint" + )} + + + + ); + }} /> diff --git a/src/components/EditSiteProvisioningKeyCredenza.tsx b/src/components/EditSiteProvisioningKeyCredenza.tsx index 9603374d5..138190edc 100644 --- a/src/components/EditSiteProvisioningKeyCredenza.tsx +++ b/src/components/EditSiteProvisioningKeyCredenza.tsx @@ -33,7 +33,10 @@ import { useEffect, useState } from "react"; import { useForm } from "react-hook-form"; import { z } from "zod"; import { zodResolver } from "@hookform/resolvers/zod"; -import moment from "moment"; +import { + DateTimePicker, + DateTimeValue +} from "@app/components/DateTimePicker"; const FORM_ID = "edit-site-provisioning-key-form"; @@ -109,9 +112,7 @@ export default function EditSiteProvisioningKeyCredenza({ name: provisioningKey.name, unlimitedBatchSize: provisioningKey.maxBatchSize == null, maxBatchSize: provisioningKey.maxBatchSize ?? 100, - validUntil: provisioningKey.validUntil - ? 
moment(provisioningKey.validUntil).format("YYYY-MM-DDTHH:mm") - : "" + validUntil: provisioningKey.validUntil ?? "" }); }, [open, provisioningKey, form]); @@ -257,23 +258,73 @@ export default function EditSiteProvisioningKeyCredenza({ ( - - - {t("provisioningKeysValidUntil")} - - - - - - {t("provisioningKeysValidUntilHint")} - - - - )} + render={({ field }) => { + const dateTimeValue: DateTimeValue = + (() => { + if (!field.value) return {}; + const d = new Date(field.value); + if (isNaN(d.getTime())) return {}; + const hours = d + .getHours() + .toString() + .padStart(2, "0"); + const minutes = d + .getMinutes() + .toString() + .padStart(2, "0"); + const seconds = d + .getSeconds() + .toString() + .padStart(2, "0"); + return { + date: d, + time: `${hours}:${minutes}:${seconds}` + }; + })(); + + return ( + + + {t("provisioningKeysValidUntil")} + + + { + if (!value.date) { + field.onChange(""); + return; + } + const d = new Date( + value.date + ); + if (value.time) { + const [h, m, s] = + value.time.split( + ":" + ); + d.setHours( + parseInt(h, 10), + parseInt(m, 10), + parseInt( + s || "0", + 10 + ) + ); + } + field.onChange( + d.toISOString() + ); + }} + /> + + + {t("provisioningKeysValidUntilHint")} + + + + ); + }} />