Compare commits

..

5 Commits

Author SHA1 Message Date
Owen Schwartz
bdc45887f9 Add chainId to dedup messages (#2737)
* ChainId send through on sensitive messages
2026-03-29 12:08:29 -07:00
Owen Schwartz
6d7a19b0a0 Merge pull request #2716 from fosrl/patch-1
Add typecasts
2026-03-25 22:12:59 -07:00
Owen
6b3a6fa380 Add typecasts 2026-03-25 22:11:56 -07:00
Owen Schwartz
e2a65b4b74 Merge pull request #2715 from fosrl/batch-band
Batch set bandwidth
2026-03-25 21:54:44 -07:00
Owen
1f01108b62 Batch set bandwidth 2026-03-25 21:53:20 -07:00
5 changed files with 97 additions and 86 deletions

View File

@@ -1,6 +1,5 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { eq, sql } from "drizzle-orm"; import { sql } from "drizzle-orm";
import { sites } from "@server/db";
import { db } from "@server/db"; import { db } from "@server/db";
import logger from "@server/logger"; import logger from "@server/logger";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
@@ -31,7 +30,10 @@ const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50; const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database // How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 30_000; // 30 seconds const FLUSH_INTERVAL_MS = 300_000; // 300 seconds
// Maximum number of sites to include in a single batch UPDATE statement
const BATCH_CHUNK_SIZE = 250;
// In-memory accumulator: publicKey -> AccumulatorEntry // In-memory accumulator: publicKey -> AccumulatorEntry
let accumulator = new Map<string, AccumulatorEntry>(); let accumulator = new Map<string, AccumulatorEntry>();
@@ -75,13 +77,33 @@ async function withDeadlockRetry<T>(
} }
} }
/**
 * Execute a raw SQL query that returns rows, in a way that works across both
 * the PostgreSQL driver (which exposes `execute`) and the SQLite driver (which
 * exposes `all`). Drizzle's typed query builder doesn't support bulk
 * UPDATE … FROM (VALUES …) natively, so callers drop to raw SQL and use this
 * helper to read the resulting rows uniformly.
 *
 * @typeParam T - Shape of each returned row. This is an unchecked cast: the
 *   rows are not validated against `T` at runtime.
 * @param query - A SQL chunk built with the `sql` template tag.
 * @returns The rows produced by the statement, or an empty array when the
 *   `execute` path yields no row container.
 * @throws Error when `db` exposes neither `execute` nor `all`, plus whatever
 *   the underlying driver call rejects with.
 */
async function dbQueryRows<T extends Record<string, unknown>>(
    query: Parameters<(typeof sql)["join"]>[0][number]
): Promise<T[]> {
    // View `db` through the minimal structural surface used here instead of
    // `any`, so the rest of the body stays type-checked.
    const rawDb = db as unknown as {
        execute?: (q: typeof query) => Promise<unknown>;
        all?: (q: typeof query) => unknown;
    };

    if (typeof rawDb.execute === "function") {
        // PostgreSQL (node-postgres via Drizzle) — returns { rows: [...] } or an array
        const result = await rawDb.execute(query);
        if (Array.isArray(result)) {
            return result as T[];
        }
        // Optional chaining guards against a null/undefined result, which
        // would otherwise throw a TypeError on the property access.
        const rows = (result as { rows?: unknown } | null | undefined)?.rows;
        return (rows ?? []) as T[];
    }

    if (typeof rawDb.all === "function") {
        // SQLite (better-sqlite3 via Drizzle) — returns an array directly
        return (await rawDb.all(query)) as T[];
    }

    // Fail with an actionable message rather than "all is not a function".
    throw new Error(
        "dbQueryRows: database client exposes neither execute() nor all()"
    );
}
/** /**
* Flush all accumulated site bandwidth data to the database. * Flush all accumulated site bandwidth data to the database.
* *
* Swaps out the accumulator before writing so that any bandwidth messages * Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than * received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued * being lost or causing contention. Sites are updated in chunks via a single
* back into the accumulator so they will be retried on the next flush. * batch UPDATE per chunk. Failed chunks are discarded — exact per-flush
* accuracy is not critical and re-queuing is not worth the added complexity.
* *
* This function is exported so that the application's graceful-shutdown * This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits. * cleanup handler can call it before the process exits.
@@ -108,76 +130,76 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
`Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database` `Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
); );
// Aggregate billing usage by org, collected during the DB update loop. // Build a lookup so post-processing can reach each entry by publicKey.
const snapshotMap = new Map(sortedEntries);
// Aggregate billing usage by org across all chunks.
const orgUsageMap = new Map<string, number>(); const orgUsageMap = new Map<string, number>();
for (const [publicKey, { bytesIn, bytesOut, exitNodeId, calcUsage }] of sortedEntries) { // Process in chunks so individual queries stay at a reasonable size.
try { for (let i = 0; i < sortedEntries.length; i += BATCH_CHUNK_SIZE) {
const updatedSite = await withDeadlockRetry(async () => { const chunk = sortedEntries.slice(i, i + BATCH_CHUNK_SIZE);
const [result] = await db const chunkEnd = i + chunk.length - 1;
.update(sites)
.set({
megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime,
})
.where(eq(sites.pubKey, publicKey))
.returning({
orgId: sites.orgId,
siteId: sites.siteId
});
return result;
}, `flush bandwidth for site ${publicKey}`);
if (updatedSite) { // Build a parameterised VALUES list: (pubKey, bytesIn, bytesOut), ...
if (exitNodeId) { // Both PostgreSQL and SQLite (≥ 3.33.0, which better-sqlite3 bundles)
const notAllowed = await checkExitNodeOrg( // support UPDATE … FROM (VALUES …), letting us update the whole chunk
exitNodeId, // in a single query instead of N individual round-trips.
updatedSite.orgId const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) =>
sql`(${publicKey}::text, ${bytesIn}::real, ${bytesOut}::real)`
); );
const valuesClause = sql.join(valuesList, sql`, `);
let rows: { orgId: string; pubKey: string }[] = [];
try {
rows = await withDeadlockRetry(async () => {
return dbQueryRows<{ orgId: string; pubKey: string }>(sql`
UPDATE sites
SET
"bytesOut" = COALESCE("bytesOut", 0) + v.bytes_in,
"bytesIn" = COALESCE("bytesIn", 0) + v.bytes_out,
"lastBandwidthUpdate" = ${currentTime}
FROM (VALUES ${valuesClause}) AS v(pub_key, bytes_in, bytes_out)
WHERE sites."pubKey" = v.pub_key
RETURNING sites."orgId" AS "orgId", sites."pubKey" AS "pubKey"
`);
}, `flush bandwidth chunk [${i}${chunkEnd}]`);
} catch (error) {
logger.error(
`Failed to flush bandwidth chunk [${i}${chunkEnd}], discarding ${chunk.length} site(s):`,
error
);
// Discard the chunk — exact per-flush accuracy is not critical.
continue;
}
// Collect billing usage from the returned rows.
for (const { orgId, pubKey } of rows) {
const entry = snapshotMap.get(pubKey);
if (!entry) continue;
const { bytesIn, bytesOut, exitNodeId, calcUsage } = entry;
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(exitNodeId, orgId);
if (notAllowed) { if (notAllowed) {
logger.warn( logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}` `Exit node ${exitNodeId} is not allowed for org ${orgId}`
); );
// Skip usage tracking for this site but continue
// processing the rest.
continue; continue;
} }
} }
if (calcUsage) { if (calcUsage) {
const totalBandwidth = bytesIn + bytesOut; const current = orgUsageMap.get(orgId) ?? 0;
const current = orgUsageMap.get(updatedSite.orgId) ?? 0; orgUsageMap.set(orgId, current + bytesIn + bytesOut);
orgUsageMap.set(updatedSite.orgId, current + totalBandwidth);
}
}
} catch (error) {
logger.error(
`Failed to flush bandwidth for site ${publicKey}:`,
error
);
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, {
bytesIn,
bytesOut,
exitNodeId,
calcUsage
});
} }
} }
} }
// Process billing usage updates outside the site-update loop to keep // Process billing usage updates after all chunks are written.
// lock scope small and concerns separated.
if (orgUsageMap.size > 0) { if (orgUsageMap.size > 0) {
// Sort org IDs for consistent lock ordering.
const sortedOrgIds = [...orgUsageMap.keys()].sort(); const sortedOrgIds = [...orgUsageMap.keys()].sort();
for (const orgId of sortedOrgIds) { for (const orgId of sortedOrgIds) {

View File

@@ -8,13 +8,6 @@ import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { buildClientConfigurationForNewtClient } from "./buildConfiguration"; import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks"; import { canCompress } from "@server/lib/clientVersionChecks";
const inputSchema = z.object({
publicKey: z.string(),
port: z.int().positive()
});
type Input = z.infer<typeof inputSchema>;
export const handleGetConfigMessage: MessageHandler = async (context) => { export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context; const { message, client, sendToClient } = context;
const newt = client as Newt; const newt = client as Newt;
@@ -33,16 +26,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
return; return;
} }
const parsed = inputSchema.safeParse(message.data); const { publicKey, port, chainId } = message.data;
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, port } = message.data as Input;
const siteId = newt.siteId; const siteId = newt.siteId;
// Get the current site data // Get the current site data
@@ -133,7 +117,8 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
data: { data: {
ipAddress: site.address, ipAddress: site.address,
peers, peers,
targets targets,
chainId: chainId
} }
}, },
options: { options: {

View File

@@ -33,7 +33,7 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => {
return; return;
} }
const { noCloud } = message.data; const { noCloud, chainId } = message.data;
const exitNodesList = await listExitNodes( const exitNodesList = await listExitNodes(
site.orgId, site.orgId,
@@ -98,7 +98,8 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => {
message: { message: {
type: "newt/ping/exitNodes", type: "newt/ping/exitNodes",
data: { data: {
exitNodes: filteredExitNodes exitNodes: filteredExitNodes,
chainId: chainId
} }
}, },
broadcast: false, // Send to all clients broadcast: false, // Send to all clients

View File

@@ -43,7 +43,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const siteId = newt.siteId; const siteId = newt.siteId;
const { publicKey, pingResults, newtVersion, backwardsCompatible } = const { publicKey, pingResults, newtVersion, backwardsCompatible, chainId } =
message.data; message.data;
if (!publicKey) { if (!publicKey) {
logger.warn("Public key not provided"); logger.warn("Public key not provided");
@@ -211,7 +211,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
udp: udpTargets, udp: udpTargets,
tcp: tcpTargets tcp: tcpTargets
}, },
healthCheckTargets: validHealthCheckTargets healthCheckTargets: validHealthCheckTargets,
chainId: chainId
} }
}, },
options: { options: {

View File

@@ -41,7 +41,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
orgId, orgId,
userToken, userToken,
fingerprint, fingerprint,
postures postures,
chainId
} = message.data; } = message.data;
if (!olm.clientId) { if (!olm.clientId) {
@@ -293,7 +294,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
data: { data: {
sites: siteConfigurations, sites: siteConfigurations,
tunnelIP: client.subnet, tunnelIP: client.subnet,
utilitySubnet: org.utilitySubnet utilitySubnet: org.utilitySubnet,
chainId: chainId
} }
}, },
options: { options: {