From 48abb9e98c696930ef27d1ce96efa01db0fae74a Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 8 Apr 2026 22:04:12 -0400 Subject: [PATCH] Breakout sites tables --- server/db/pg/schema/schema.ts | 49 ++++++++++++++--- server/db/sqlite/schema/schema.ts | 49 ++++++++++++++--- server/lib/telemetry.ts | 9 ++-- .../hooks/handleSubscriptionUpdated.ts | 14 +++-- server/routers/client/listClients.ts | 20 ++++--- server/routers/client/listUserDevices.ts | 20 ++++--- server/routers/gerbil/receiveBandwidth.ts | 53 ++++++++++++------ server/routers/newt/handleNewtPingMessage.ts | 33 ++++-------- .../newt/handleReceiveBandwidthMessage.ts | 39 ++++++++++---- server/routers/newt/pingAccumulator.ts | 54 +++++++++++-------- server/routers/olm/handleOlmPingMessage.ts | 32 +++++++---- server/routers/org/resetOrgBandwidth.ts | 13 +++-- server/routers/site/listSites.ts | 20 ++++--- server/setup/migrationsPg.ts | 4 +- server/setup/migrationsSqlite.ts | 4 +- 15 files changed, 283 insertions(+), 130 deletions(-) diff --git a/server/db/pg/schema/schema.ts b/server/db/pg/schema/schema.ts index bde3e9aec..8a0ff39e1 100644 --- a/server/db/pg/schema/schema.ts +++ b/server/db/pg/schema/schema.ts @@ -89,12 +89,8 @@ export const sites = pgTable("sites", { name: varchar("name").notNull(), pubKey: varchar("pubKey"), subnet: varchar("subnet"), - megabytesIn: real("bytesIn").default(0), - megabytesOut: real("bytesOut").default(0), - lastBandwidthUpdate: varchar("lastBandwidthUpdate"), type: varchar("type").notNull(), // "newt" or "wireguard" online: boolean("online").notNull().default(false), - lastPing: integer("lastPing"), address: varchar("address"), endpoint: varchar("endpoint"), publicKey: varchar("publicKey"), @@ -729,10 +725,7 @@ export const clients = pgTable("clients", { name: varchar("name").notNull(), pubKey: varchar("pubKey"), subnet: varchar("subnet").notNull(), - megabytesIn: real("bytesIn"), - megabytesOut: real("bytesOut"), - lastBandwidthUpdate: varchar("lastBandwidthUpdate"), - 
lastPing: integer("lastPing"), + type: varchar("type").notNull(), // "olm" online: boolean("online").notNull().default(false), // endpoint: varchar("endpoint"), @@ -745,6 +738,42 @@ export const clients = pgTable("clients", { >() }); +export const sitePing = pgTable("sitePing", { + siteId: integer("siteId") + .primaryKey() + .references(() => sites.siteId, { onDelete: "cascade" }) + .notNull(), + lastPing: integer("lastPing") +}); + +export const siteBandwidth = pgTable("siteBandwidth", { + siteId: integer("siteId") + .primaryKey() + .references(() => sites.siteId, { onDelete: "cascade" }) + .notNull(), + megabytesIn: real("bytesIn").default(0), + megabytesOut: real("bytesOut").default(0), + lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch +}); + +export const clientPing = pgTable("clientPing", { + clientId: integer("clientId") + .primaryKey() + .references(() => clients.clientId, { onDelete: "cascade" }) + .notNull(), + lastPing: integer("lastPing") +}); + +export const clientBandwidth = pgTable("clientBandwidth", { + clientId: integer("clientId") + .primaryKey() + .references(() => clients.clientId, { onDelete: "cascade" }) + .notNull(), + megabytesIn: real("bytesIn"), + megabytesOut: real("bytesOut"), + lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch +}); + export const clientSitesAssociationsCache = pgTable( "clientSitesAssociationsCache", { @@ -1106,3 +1135,7 @@ export type RequestAuditLog = InferSelectModel; export type RoundTripMessageTracker = InferSelectModel< typeof roundTripMessageTracker >; +export type SitePing = typeof sitePing.$inferSelect; +export type SiteBandwidth = typeof siteBandwidth.$inferSelect; +export type ClientPing = typeof clientPing.$inferSelect; +export type ClientBandwidth = typeof clientBandwidth.$inferSelect; diff --git a/server/db/sqlite/schema/schema.ts b/server/db/sqlite/schema/schema.ts index 1fb04ef14..e62fefece 100644 --- a/server/db/sqlite/schema/schema.ts +++ 
b/server/db/sqlite/schema/schema.ts @@ -95,12 +95,8 @@ export const sites = sqliteTable("sites", { name: text("name").notNull(), pubKey: text("pubKey"), subnet: text("subnet"), - megabytesIn: integer("bytesIn").default(0), - megabytesOut: integer("bytesOut").default(0), - lastBandwidthUpdate: text("lastBandwidthUpdate"), type: text("type").notNull(), // "newt" or "wireguard" online: integer("online", { mode: "boolean" }).notNull().default(false), - lastPing: integer("lastPing"), // exit node stuff that is how to connect to the site when it has a wg server address: text("address"), // this is the address of the wireguard interface in newt @@ -399,10 +395,7 @@ export const clients = sqliteTable("clients", { pubKey: text("pubKey"), olmId: text("olmId"), // to lock it to a specific olm optionally subnet: text("subnet").notNull(), - megabytesIn: integer("bytesIn"), - megabytesOut: integer("bytesOut"), - lastBandwidthUpdate: text("lastBandwidthUpdate"), - lastPing: integer("lastPing"), + type: text("type").notNull(), // "olm" online: integer("online", { mode: "boolean" }).notNull().default(false), // endpoint: text("endpoint"), @@ -414,6 +407,42 @@ export const clients = sqliteTable("clients", { >() }); +export const sitePing = sqliteTable("sitePing", { + siteId: integer("siteId") + .primaryKey() + .references(() => sites.siteId, { onDelete: "cascade" }) + .notNull(), + lastPing: integer("lastPing") +}); + +export const siteBandwidth = sqliteTable("siteBandwidth", { + siteId: integer("siteId") + .primaryKey() + .references(() => sites.siteId, { onDelete: "cascade" }) + .notNull(), + megabytesIn: integer("bytesIn").default(0), + megabytesOut: integer("bytesOut").default(0), + lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch +}); + +export const clientPing = sqliteTable("clientPing", { + clientId: integer("clientId") + .primaryKey() + .references(() => clients.clientId, { onDelete: "cascade" }) + .notNull(), + lastPing: integer("lastPing") +}); + +export 
const clientBandwidth = sqliteTable("clientBandwidth", { + clientId: integer("clientId") + .primaryKey() + .references(() => clients.clientId, { onDelete: "cascade" }) + .notNull(), + megabytesIn: integer("bytesIn"), + megabytesOut: integer("bytesOut"), + lastBandwidthUpdate: integer("lastBandwidthUpdate") // unix epoch +}); + export const clientSitesAssociationsCache = sqliteTable( "clientSitesAssociationsCache", { @@ -1209,3 +1238,7 @@ export type DeviceWebAuthCode = InferSelectModel; export type RoundTripMessageTracker = InferSelectModel< typeof roundTripMessageTracker >; +export type SitePing = typeof sitePing.$inferSelect; +export type SiteBandwidth = typeof siteBandwidth.$inferSelect; +export type ClientPing = typeof clientPing.$inferSelect; +export type ClientBandwidth = typeof clientBandwidth.$inferSelect; diff --git a/server/lib/telemetry.ts b/server/lib/telemetry.ts index fda59f394..bc9de372a 100644 --- a/server/lib/telemetry.ts +++ b/server/lib/telemetry.ts @@ -3,7 +3,7 @@ import config from "./config"; import { getHostMeta } from "./hostMeta"; import logger from "@server/logger"; import { apiKeys, db, roles, siteResources } from "@server/db"; -import { sites, users, orgs, resources, clients, idp } from "@server/db"; +import { sites, users, orgs, resources, clients, idp, siteBandwidth } from "@server/db"; import { eq, count, notInArray, and, isNotNull, isNull } from "drizzle-orm"; import { APP_VERSION } from "./consts"; import crypto from "crypto"; @@ -150,12 +150,13 @@ class TelemetryClient { const siteDetails = await db .select({ siteName: sites.name, - megabytesIn: sites.megabytesIn, - megabytesOut: sites.megabytesOut, + megabytesIn: siteBandwidth.megabytesIn, + megabytesOut: siteBandwidth.megabytesOut, type: sites.type, online: sites.online }) - .from(sites); + .from(sites) + .leftJoin(siteBandwidth, eq(siteBandwidth.siteId, sites.siteId)); const supporterKey = config.getSupporterData(); diff --git 
a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts index 0305e7f1b..3ab593890 100644 --- a/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts +++ b/server/private/routers/billing/hooks/handleSubscriptionUpdated.ts @@ -18,10 +18,11 @@ import { subscriptionItems, usage, sites, + siteBandwidth, customers, orgs } from "@server/db"; -import { eq, and } from "drizzle-orm"; +import { eq, and, inArray } from "drizzle-orm"; import logger from "@server/logger"; import { getFeatureIdByMetricId, getFeatureIdByPriceId } from "@server/lib/billing/features"; import stripe from "#private/lib/stripe"; @@ -253,14 +254,19 @@ export async function handleSubscriptionUpdated( ); } - // Also reset the sites to 0 + // Also reset the site bandwidth to 0 await trx - .update(sites) + .update(siteBandwidth) .set({ megabytesIn: 0, megabytesOut: 0 }) - .where(eq(sites.orgId, orgId)); + .where( + inArray( + siteBandwidth.siteId, + trx.select({ siteId: sites.siteId }).from(sites).where(eq(sites.orgId, orgId)) + ) + ); }); } } diff --git a/server/routers/client/listClients.ts b/server/routers/client/listClients.ts index 0bf798509..8ce6a1218 100644 --- a/server/routers/client/listClients.ts +++ b/server/routers/client/listClients.ts @@ -1,4 +1,5 @@ import { + clientBandwidth, clients, clientSitesAssociationsCache, currentFingerprint, @@ -180,8 +181,8 @@ function queryClientsBase() { name: clients.name, pubKey: clients.pubKey, subnet: clients.subnet, - megabytesIn: clients.megabytesIn, - megabytesOut: clients.megabytesOut, + megabytesIn: clientBandwidth.megabytesIn, + megabytesOut: clientBandwidth.megabytesOut, orgName: orgs.name, type: clients.type, online: clients.online, @@ -200,7 +201,8 @@ function queryClientsBase() { .leftJoin(orgs, eq(clients.orgId, orgs.orgId)) .leftJoin(olms, eq(clients.clientId, olms.clientId)) .leftJoin(users, eq(clients.userId, users.userId)) - 
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)); + .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)) + .leftJoin(clientBandwidth, eq(clientBandwidth.clientId, clients.clientId)); } async function getSiteAssociations(clientIds: number[]) { @@ -367,9 +369,15 @@ export async function listClients( .offset(pageSize * (page - 1)) .orderBy( sort_by - ? order === "asc" - ? asc(clients[sort_by]) - : desc(clients[sort_by]) + ? (() => { + const field = + sort_by === "megabytesIn" + ? clientBandwidth.megabytesIn + : sort_by === "megabytesOut" + ? clientBandwidth.megabytesOut + : clients.name; + return order === "asc" ? asc(field) : desc(field); + })() : asc(clients.name) ); diff --git a/server/routers/client/listUserDevices.ts b/server/routers/client/listUserDevices.ts index 0ae31165a..dd5321507 100644 --- a/server/routers/client/listUserDevices.ts +++ b/server/routers/client/listUserDevices.ts @@ -1,5 +1,6 @@ import { build } from "@server/build"; import { + clientBandwidth, clients, currentFingerprint, db, @@ -211,8 +212,8 @@ function queryUserDevicesBase() { name: clients.name, pubKey: clients.pubKey, subnet: clients.subnet, - megabytesIn: clients.megabytesIn, - megabytesOut: clients.megabytesOut, + megabytesIn: clientBandwidth.megabytesIn, + megabytesOut: clientBandwidth.megabytesOut, orgName: orgs.name, type: clients.type, online: clients.online, @@ -239,7 +240,8 @@ function queryUserDevicesBase() { .leftJoin(orgs, eq(clients.orgId, orgs.orgId)) .leftJoin(olms, eq(clients.clientId, olms.clientId)) .leftJoin(users, eq(clients.userId, users.userId)) - .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)); + .leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId)) + .leftJoin(clientBandwidth, eq(clientBandwidth.clientId, clients.clientId)); } type OlmWithUpdateAvailable = Awaited< @@ -427,9 +429,15 @@ export async function listUserDevices( .offset(pageSize * (page - 1)) .orderBy( sort_by - ? 
order === "asc" - ? asc(clients[sort_by]) - : desc(clients[sort_by]) + ? (() => { + const field = + sort_by === "megabytesIn" + ? clientBandwidth.megabytesIn + : sort_by === "megabytesOut" + ? clientBandwidth.megabytesOut + : clients.name; + return order === "asc" ? asc(field) : desc(field); + })() : asc(clients.clientId) ); diff --git a/server/routers/gerbil/receiveBandwidth.ts b/server/routers/gerbil/receiveBandwidth.ts index dcd897471..7a218754a 100644 --- a/server/routers/gerbil/receiveBandwidth.ts +++ b/server/routers/gerbil/receiveBandwidth.ts @@ -122,7 +122,7 @@ export async function flushSiteBandwidthToDb(): Promise { const snapshot = accumulator; accumulator = new Map(); - const currentTime = new Date().toISOString(); + const currentEpoch = Math.floor(Date.now() / 1000); // Sort by publicKey for consistent lock ordering across concurrent // writers — deadlock-prevention strategy. @@ -157,33 +157,52 @@ export async function flushSiteBandwidthToDb(): Promise { orgId: string; pubKey: string; }>(sql` - UPDATE sites - SET - "bytesOut" = COALESCE("bytesOut", 0) + ${bytesIn}, - "bytesIn" = COALESCE("bytesIn", 0) + ${bytesOut}, - "lastBandwidthUpdate" = ${currentTime} - WHERE "pubKey" = ${publicKey} - RETURNING "orgId", "pubKey" + WITH upsert AS ( + INSERT INTO "siteBandwidth" ("siteId", "bytesIn", "bytesOut", "lastBandwidthUpdate") + SELECT s."siteId", ${bytesOut}, ${bytesIn}, ${currentEpoch} + FROM "sites" s WHERE s."pubKey" = ${publicKey} + ON CONFLICT ("siteId") DO UPDATE SET + "bytesIn" = COALESCE("siteBandwidth"."bytesIn", 0) + EXCLUDED."bytesIn", + "bytesOut" = COALESCE("siteBandwidth"."bytesOut", 0) + EXCLUDED."bytesOut", + "lastBandwidthUpdate" = EXCLUDED."lastBandwidthUpdate" + RETURNING "siteId" + ) + SELECT u."siteId", s."orgId", s."pubKey" + FROM upsert u + INNER JOIN "sites" s ON s."siteId" = u."siteId" `); results.push(...result); } return results; } - // PostgreSQL: batch UPDATE … FROM (VALUES …) — single round-trip per chunk.
+ // PostgreSQL: batch UPSERT via CTE — single round-trip per chunk. const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) => sql`(${publicKey}::text, ${bytesIn}::real, ${bytesOut}::real)` ); const valuesClause = sql.join(valuesList, sql`, `); return dbQueryRows<{ orgId: string; pubKey: string }>(sql` - UPDATE sites - SET - "bytesOut" = COALESCE("bytesOut", 0) + v.bytes_in, - "bytesIn" = COALESCE("bytesIn", 0) + v.bytes_out, - "lastBandwidthUpdate" = ${currentTime} - FROM (VALUES ${valuesClause}) AS v(pub_key, bytes_in, bytes_out) - WHERE sites."pubKey" = v.pub_key - RETURNING sites."orgId" AS "orgId", sites."pubKey" AS "pubKey" + WITH vals(pub_key, bytes_in, bytes_out) AS ( + VALUES ${valuesClause} + ), + site_lookup AS ( + SELECT s."siteId", s."orgId", s."pubKey", v.bytes_in, v.bytes_out + FROM vals v + INNER JOIN "sites" s ON s."pubKey" = v.pub_key + ), + upsert AS ( + INSERT INTO "siteBandwidth" ("siteId", "bytesIn", "bytesOut", "lastBandwidthUpdate") + SELECT sl."siteId", sl.bytes_out, sl.bytes_in, ${currentEpoch}::integer + FROM site_lookup sl + ON CONFLICT ("siteId") DO UPDATE SET + "bytesIn" = COALESCE("siteBandwidth"."bytesIn", 0) + EXCLUDED."bytesIn", + "bytesOut" = COALESCE("siteBandwidth"."bytesOut", 0) + EXCLUDED."bytesOut", + "lastBandwidthUpdate" = EXCLUDED."lastBandwidthUpdate" + RETURNING "siteId" + ) + SELECT u."siteId", s."orgId", s."pubKey" + FROM upsert u + INNER JOIN "sites" s ON s."siteId" = u."siteId" `); }, `flush bandwidth chunk [${i}–${chunkEnd}]`); } catch (error) { diff --git a/server/routers/newt/handleNewtPingMessage.ts b/server/routers/newt/handleNewtPingMessage.ts index 32f665758..649539589 100644 --- a/server/routers/newt/handleNewtPingMessage.ts +++ b/server/routers/newt/handleNewtPingMessage.ts @@ -1,11 +1,11 @@ -import { db, newts, sites, targetHealthCheck, targets } from "@server/db"; +import { db, newts, sites, targetHealthCheck, targets, sitePing, siteBandwidth } from "@server/db"; import { hasActiveConnections,
getClientConfigVersion } from "#dynamic/routers/ws"; import { MessageHandler } from "@server/routers/ws"; import { Newt } from "@server/db"; -import { eq, lt, isNull, and, or, ne, not } from "drizzle-orm"; +import { eq, lt, isNull, and, or, ne } from "drizzle-orm"; import logger from "@server/logger"; import { sendNewtSyncMessage } from "./sync"; import { recordPing } from "./pingAccumulator"; @@ -41,17 +41,18 @@ export const startNewtOfflineChecker = (): void => { .select({ siteId: sites.siteId, newtId: newts.newtId, - lastPing: sites.lastPing + lastPing: sitePing.lastPing }) .from(sites) .innerJoin(newts, eq(newts.siteId, sites.siteId)) + .leftJoin(sitePing, eq(sitePing.siteId, sites.siteId)) .where( and( eq(sites.online, true), eq(sites.type, "newt"), or( - lt(sites.lastPing, twoMinutesAgo), - isNull(sites.lastPing) + lt(sitePing.lastPing, twoMinutesAgo), + isNull(sitePing.lastPing) ) ) ); @@ -112,15 +113,11 @@ export const startNewtOfflineChecker = (): void => { .select({ siteId: sites.siteId, online: sites.online, - lastBandwidthUpdate: sites.lastBandwidthUpdate + lastBandwidthUpdate: siteBandwidth.lastBandwidthUpdate }) .from(sites) - .where( - and( - eq(sites.type, "wireguard"), - not(isNull(sites.lastBandwidthUpdate)) - ) - ); + .innerJoin(siteBandwidth, eq(siteBandwidth.siteId, sites.siteId)) + .where(eq(sites.type, "wireguard")); const wireguardOfflineThreshold = Math.floor( (Date.now() - OFFLINE_THRESHOLD_BANDWIDTH_MS) / 1000 @@ -128,12 +125,7 @@ export const startNewtOfflineChecker = (): void => { // loop over each one. If its offline and there is a new update then mark it online. If its online and there is no update then mark it offline for (const site of allWireguardSites) { - const lastBandwidthUpdate = - new Date(site.lastBandwidthUpdate!).getTime() / 1000; - if ( - lastBandwidthUpdate < wireguardOfflineThreshold && - site.online - ) { + if ((site.lastBandwidthUpdate ?? 
0) < wireguardOfflineThreshold && site.online) { logger.info( `Marking wireguard site ${site.siteId} offline: no bandwidth update in over ${OFFLINE_THRESHOLD_BANDWIDTH_MS / 60000} minutes` ); @@ -142,10 +134,7 @@ export const startNewtOfflineChecker = (): void => { .update(sites) .set({ online: false }) .where(eq(sites.siteId, site.siteId)); - } else if ( - lastBandwidthUpdate >= wireguardOfflineThreshold && - !site.online - ) { + } else if ((site.lastBandwidthUpdate ?? 0) >= wireguardOfflineThreshold && !site.online) { logger.info( `Marking wireguard site ${site.siteId} online: recent bandwidth update` ); diff --git a/server/routers/newt/handleReceiveBandwidthMessage.ts b/server/routers/newt/handleReceiveBandwidthMessage.ts index f086333e7..19730d1dc 100644 --- a/server/routers/newt/handleReceiveBandwidthMessage.ts +++ b/server/routers/newt/handleReceiveBandwidthMessage.ts @@ -1,6 +1,5 @@ -import { db } from "@server/db"; +import { db, clients, clientBandwidth } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { clients } from "@server/db"; import { eq, sql } from "drizzle-orm"; import logger from "@server/logger"; @@ -85,7 +84,7 @@ export async function flushBandwidthToDb(): Promise { const snapshot = accumulator; accumulator = new Map(); - const currentTime = new Date().toISOString(); + const currentEpoch = Math.floor(Date.now() / 1000); // Sort by publicKey for consistent lock ordering across concurrent // writers — this is the same deadlock-prevention strategy used in the @@ -101,19 +100,37 @@ export async function flushBandwidthToDb(): Promise { for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) { try { await withDeadlockRetry(async () => { - // Use atomic SQL increment to avoid the SELECT-then-UPDATE - // anti-pattern and the races it would introduce. 
+ // Find clientId by pubKey + const [clientRow] = await db + .select({ clientId: clients.clientId }) + .from(clients) + .where(eq(clients.pubKey, publicKey)) + .limit(1); + + if (!clientRow) { + logger.warn(`No client found for pubKey ${publicKey}, skipping`); + return; + } + await db - .update(clients) - .set({ + .insert(clientBandwidth) + .values({ + clientId: clientRow.clientId, // Note: bytesIn from peer goes to megabytesOut (data // sent to client) and bytesOut from peer goes to // megabytesIn (data received from client). - megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`, - megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`, - lastBandwidthUpdate: currentTime + megabytesOut: bytesIn, + megabytesIn: bytesOut, + lastBandwidthUpdate: currentEpoch }) - .where(eq(clients.pubKey, publicKey)); + .onConflictDoUpdate({ + target: clientBandwidth.clientId, + set: { + megabytesOut: sql`COALESCE(${clientBandwidth.megabytesOut}, 0) + ${bytesIn}`, + megabytesIn: sql`COALESCE(${clientBandwidth.megabytesIn}, 0) + ${bytesOut}`, + lastBandwidthUpdate: currentEpoch + } + }); }, `flush bandwidth for client ${publicKey}`); } catch (error) { logger.error( diff --git a/server/routers/newt/pingAccumulator.ts b/server/routers/newt/pingAccumulator.ts index fe2cde216..32a865bac 100644 --- a/server/routers/newt/pingAccumulator.ts +++ b/server/routers/newt/pingAccumulator.ts @@ -1,6 +1,6 @@ import { db } from "@server/db"; -import { sites, clients, olms } from "@server/db"; -import { inArray } from "drizzle-orm"; +import { sites, clients, olms, sitePing, clientPing } from "@server/db"; +import { inArray, sql } from "drizzle-orm"; import logger from "@server/logger"; /** @@ -81,11 +81,8 @@ export function recordClientPing( /** * Flush all accumulated site pings to the database. * - * Each batch of up to BATCH_SIZE rows is written with a **single** UPDATE - * statement. 
We use the maximum timestamp across the batch so that `lastPing` - * reflects the most recent ping seen for any site in the group. This avoids - * the multi-statement transaction that previously created additional - * row-lock ordering hazards. + * For each batch: first upserts individual per-site timestamps into + * `sitePing`, then bulk-updates `sites.online = true`. */ async function flushSitePingsToDb(): Promise { if (pendingSitePings.size === 0) { @@ -103,20 +100,25 @@ async function flushSitePingsToDb(): Promise { for (let i = 0; i < entries.length; i += BATCH_SIZE) { const batch = entries.slice(i, i + BATCH_SIZE); - // Use the latest timestamp in the batch so that `lastPing` always - // moves forward. Using a single timestamp for the whole batch means - // we only ever need one UPDATE statement (no transaction). - const maxTimestamp = Math.max(...batch.map(([, ts]) => ts)); const siteIds = batch.map(([id]) => id); try { await withRetry(async () => { + const rows = batch.map(([siteId, ts]) => ({ siteId, lastPing: ts })); + + // Step 1: Upsert ping timestamps into sitePing + await db + .insert(sitePing) + .values(rows) + .onConflictDoUpdate({ + target: sitePing.siteId, + set: { lastPing: sql`excluded."lastPing"` } + }); + + // Step 2: Update online status on sites await db .update(sites) - .set({ - online: true, - lastPing: maxTimestamp - }) + .set({ online: true }) .where(inArray(sites.siteId, siteIds)); }, "flushSitePingsToDb"); } catch (error) { @@ -139,7 +141,8 @@ async function flushSitePingsToDb(): Promise { /** * Flush all accumulated client (OLM) pings to the database. * - * Same single-UPDATE-per-batch approach as `flushSitePingsToDb`. + * For each batch: first upserts individual per-client timestamps into + * `clientPing`, then bulk-updates `clients.online = true, archived = false`. 
*/ async function flushClientPingsToDb(): Promise { if (pendingClientPings.size === 0 && pendingOlmArchiveResets.size === 0) { @@ -161,18 +164,25 @@ async function flushClientPingsToDb(): Promise { for (let i = 0; i < entries.length; i += BATCH_SIZE) { const batch = entries.slice(i, i + BATCH_SIZE); - const maxTimestamp = Math.max(...batch.map(([, ts]) => ts)); const clientIds = batch.map(([id]) => id); try { await withRetry(async () => { + const rows = batch.map(([clientId, ts]) => ({ clientId, lastPing: ts })); + + // Step 1: Upsert ping timestamps into clientPing + await db + .insert(clientPing) + .values(rows) + .onConflictDoUpdate({ + target: clientPing.clientId, + set: { lastPing: sql`excluded."lastPing"` } + }); + + // Step 2: Update online + unarchive on clients await db .update(clients) - .set({ - lastPing: maxTimestamp, - online: true, - archived: false - }) + .set({ online: true, archived: false }) .where(inArray(clients.clientId, clientIds)); }, "flushClientPingsToDb"); } catch (error) { diff --git a/server/routers/olm/handleOlmPingMessage.ts b/server/routers/olm/handleOlmPingMessage.ts index 0f520b234..8f770a685 100644 --- a/server/routers/olm/handleOlmPingMessage.ts +++ b/server/routers/olm/handleOlmPingMessage.ts @@ -1,8 +1,8 @@ import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws"; import { db } from "@server/db"; import { MessageHandler } from "@server/routers/ws"; -import { clients, olms, Olm } from "@server/db"; -import { eq, lt, isNull, and, or } from "drizzle-orm"; +import { clients, olms, Olm, clientPing } from "@server/db"; +import { eq, lt, isNull, and, or, inArray } from "drizzle-orm"; import { recordClientPing } from "@server/routers/newt/pingAccumulator"; import logger from "@server/logger"; import { validateSessionToken } from "@server/auth/sessions/app"; @@ -37,21 +37,33 @@ export const startOlmOfflineChecker = (): void => { // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING 
// Find clients that haven't pinged in the last 2 minutes and mark them as offline - const offlineClients = await db - .update(clients) - .set({ online: false }) + const staleClientRows = await db + .select({ + clientId: clients.clientId, + olmId: clients.olmId, + lastPing: clientPing.lastPing + }) + .from(clients) + .leftJoin(clientPing, eq(clientPing.clientId, clients.clientId)) .where( and( eq(clients.online, true), or( - lt(clients.lastPing, twoMinutesAgo), - isNull(clients.lastPing) + lt(clientPing.lastPing, twoMinutesAgo), + isNull(clientPing.lastPing) ) ) - ) - .returning(); + ); - for (const offlineClient of offlineClients) { + if (staleClientRows.length > 0) { + const staleClientIds = staleClientRows.map((c) => c.clientId); + await db + .update(clients) + .set({ online: false }) + .where(inArray(clients.clientId, staleClientIds)); + } + + for (const offlineClient of staleClientRows) { logger.info( `Kicking offline olm client ${offlineClient.clientId} due to inactivity` ); diff --git a/server/routers/org/resetOrgBandwidth.ts b/server/routers/org/resetOrgBandwidth.ts index b98e2e406..453ffa32b 100644 --- a/server/routers/org/resetOrgBandwidth.ts +++ b/server/routers/org/resetOrgBandwidth.ts @@ -1,7 +1,7 @@ import { NextFunction, Request, Response } from "express"; import { z } from "zod"; -import { db, sites } from "@server/db"; -import { eq } from "drizzle-orm"; +import { db, sites, siteBandwidth } from "@server/db"; +import { eq, inArray } from "drizzle-orm"; import response from "@server/lib/response"; import HttpCode from "@server/types/HttpCode"; import createHttpError from "http-errors"; @@ -60,12 +60,17 @@ export async function resetOrgBandwidth( } await db - .update(sites) + .update(siteBandwidth) .set({ megabytesIn: 0, megabytesOut: 0 }) - .where(eq(sites.orgId, orgId)); + .where( + inArray( + siteBandwidth.siteId, + db.select({ siteId: sites.siteId }).from(sites).where(eq(sites.orgId, orgId)) + ) + ); return response(res, { data: {}, diff --git 
a/server/routers/site/listSites.ts b/server/routers/site/listSites.ts index 6f085d74d..e14880b52 100644 --- a/server/routers/site/listSites.ts +++ b/server/routers/site/listSites.ts @@ -6,6 +6,7 @@ import { remoteExitNodes, roleSites, sites, + siteBandwidth, userSites } from "@server/db"; import cache from "#dynamic/lib/cache"; @@ -155,8 +156,8 @@ function querySitesBase() { name: sites.name, pubKey: sites.pubKey, subnet: sites.subnet, - megabytesIn: sites.megabytesIn, - megabytesOut: sites.megabytesOut, + megabytesIn: siteBandwidth.megabytesIn, + megabytesOut: siteBandwidth.megabytesOut, orgName: orgs.name, type: sites.type, online: sites.online, @@ -175,7 +176,8 @@ function querySitesBase() { .leftJoin( remoteExitNodes, eq(remoteExitNodes.exitNodeId, sites.exitNodeId) - ); + ) + .leftJoin(siteBandwidth, eq(siteBandwidth.siteId, sites.siteId)); } type SiteWithUpdateAvailable = Awaited>[0] & { @@ -299,9 +301,15 @@ export async function listSites( .offset(pageSize * (page - 1)) .orderBy( sort_by - ? order === "asc" - ? asc(sites[sort_by]) - : desc(sites[sort_by]) + ? (() => { + const field = + sort_by === "megabytesIn" + ? siteBandwidth.megabytesIn + : sort_by === "megabytesOut" + ? siteBandwidth.megabytesOut + : sites.name; + return order === "asc" ? 
asc(field) : desc(field); + })() : asc(sites.name) ); diff --git a/server/setup/migrationsPg.ts b/server/setup/migrationsPg.ts index 9ba0b9767..992cc2583 100644 --- a/server/setup/migrationsPg.ts +++ b/server/setup/migrationsPg.ts @@ -22,6 +22,7 @@ import m13 from "./scriptsPg/1.15.3"; import m14 from "./scriptsPg/1.15.4"; import m15 from "./scriptsPg/1.16.0"; import m16 from "./scriptsPg/1.17.0"; +import m17 from "./scriptsPg/1.18.0"; // THIS CANNOT IMPORT ANYTHING FROM THE SERVER // EXCEPT FOR THE DATABASE AND THE SCHEMA @@ -43,7 +44,8 @@ const migrations = [ { version: "1.15.3", run: m13 }, { version: "1.15.4", run: m14 }, { version: "1.16.0", run: m15 }, - { version: "1.17.0", run: m16 } + { version: "1.17.0", run: m16 }, + { version: "1.18.0", run: m17 } // Add new migrations here as they are created ] as { version: string; diff --git a/server/setup/migrationsSqlite.ts b/server/setup/migrationsSqlite.ts index 45a29ec29..c32437aec 100644 --- a/server/setup/migrationsSqlite.ts +++ b/server/setup/migrationsSqlite.ts @@ -40,6 +40,7 @@ import m34 from "./scriptsSqlite/1.15.3"; import m35 from "./scriptsSqlite/1.15.4"; import m36 from "./scriptsSqlite/1.16.0"; import m37 from "./scriptsSqlite/1.17.0"; +import m38 from "./scriptsSqlite/1.18.0"; // THIS CANNOT IMPORT ANYTHING FROM THE SERVER // EXCEPT FOR THE DATABASE AND THE SCHEMA @@ -77,7 +78,8 @@ const migrations = [ { version: "1.15.3", run: m34 }, { version: "1.15.4", run: m35 }, { version: "1.16.0", run: m36 }, - { version: "1.17.0", run: m37 } + { version: "1.17.0", run: m37 }, + { version: "1.18.0", run: m38 } // Add new migrations here as they are created ] as const;