Compare commits

...

15 Commits
dev ... jit

Author SHA1 Message Date
Owen
1a43f1ef4b Handle newt online offline with websocket 2026-03-14 11:59:20 -07:00
Owen
75ab074805 Attempt to improve handling bandwidth tracking 2026-03-13 12:06:01 -07:00
Owen
dc4e0253de Add message compression for large messages 2026-03-13 11:46:03 -07:00
Owen
cccf236042 Add optional compression 2026-03-12 17:49:21 -07:00
Owen
63fd63c65c Send less data down 2026-03-12 17:27:15 -07:00
Owen
beee1d692d revert: telemetry comment 2026-03-12 17:11:13 -07:00
Owen
fde786ca84 Add todo 2026-03-12 17:10:46 -07:00
Owen
3086fdd064 Merge branch 'dev' into jit 2026-03-12 16:58:23 -07:00
Owen
cf5fb8dc33 Working on jit 2026-03-09 16:36:13 -07:00
Owen
0503c6e66e Handle JIT for ssh 2026-03-06 15:49:17 -08:00
Owen
9405b0b70a Force jit above site limit 2026-03-06 14:09:57 -08:00
Owen
2a5c9465e9 Add chainId field passthrough 2026-03-04 22:17:58 -08:00
Owen
f36b66e397 Merge branch 'dev' into jit 2026-03-04 17:58:50 -08:00
Owen
1bfff630bf Jit working for sites 2026-03-04 17:46:58 -08:00
Owen
c73a39f797 Allow JIT based on site or resource 2026-03-04 15:44:27 -08:00
33 changed files with 1247 additions and 599 deletions

View File

@@ -1,6 +1,10 @@
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
import { cleanup as wsCleanup } from "#dynamic/routers/ws"; import { cleanup as wsCleanup } from "#dynamic/routers/ws";
async function cleanup() { async function cleanup() {
await flushBandwidthToDb();
await flushSiteBandwidthToDb();
await wsCleanup(); await wsCleanup();
process.exit(0); process.exit(0);

View File

@@ -89,6 +89,7 @@ export const sites = pgTable("sites", {
lastBandwidthUpdate: varchar("lastBandwidthUpdate"), lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
type: varchar("type").notNull(), // "newt" or "wireguard" type: varchar("type").notNull(), // "newt" or "wireguard"
online: boolean("online").notNull().default(false), online: boolean("online").notNull().default(false),
lastPing: integer("lastPing"),
address: varchar("address"), address: varchar("address"),
endpoint: varchar("endpoint"), endpoint: varchar("endpoint"),
publicKey: varchar("publicKey"), publicKey: varchar("publicKey"),
@@ -721,6 +722,7 @@ export const clientSitesAssociationsCache = pgTable(
.notNull(), .notNull(),
siteId: integer("siteId").notNull(), siteId: integer("siteId").notNull(),
isRelayed: boolean("isRelayed").notNull().default(false), isRelayed: boolean("isRelayed").notNull().default(false),
isJitMode: boolean("isJitMode").notNull().default(false),
endpoint: varchar("endpoint"), endpoint: varchar("endpoint"),
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }

View File

@@ -90,6 +90,7 @@ export const sites = sqliteTable("sites", {
lastBandwidthUpdate: text("lastBandwidthUpdate"), lastBandwidthUpdate: text("lastBandwidthUpdate"),
type: text("type").notNull(), // "newt" or "wireguard" type: text("type").notNull(), // "newt" or "wireguard"
online: integer("online", { mode: "boolean" }).notNull().default(false), online: integer("online", { mode: "boolean" }).notNull().default(false),
lastPing: integer("lastPing"),
// exit node stuff that is how to connect to the site when it has a wg server // exit node stuff that is how to connect to the site when it has a wg server
address: text("address"), // this is the address of the wireguard interface in newt address: text("address"), // this is the address of the wireguard interface in newt
@@ -410,6 +411,9 @@ export const clientSitesAssociationsCache = sqliteTable(
isRelayed: integer("isRelayed", { mode: "boolean" }) isRelayed: integer("isRelayed", { mode: "boolean" })
.notNull() .notNull()
.default(false), .default(false),
isJitMode: integer("isJitMode", { mode: "boolean" })
.notNull()
.default(false),
endpoint: text("endpoint"), endpoint: text("endpoint"),
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
} }

View File

@@ -107,7 +107,7 @@ export async function applyBlueprint({
[target], [target],
matchingHealthcheck ? [matchingHealthcheck] : [], matchingHealthcheck ? [matchingHealthcheck] : [],
result.proxyResource.protocol, result.proxyResource.protocol,
result.proxyResource.proxyPort site.newt.version
); );
} }
} }

View File

@@ -0,0 +1,20 @@
import semver from "semver";
/**
 * Decide whether a connected client supports gzip-compressed WebSocket
 * messages, based on the client type and its reported semver version.
 *
 * @param clientVersion - Version string reported by the client; may be
 *   missing or malformed, in which case compression is disabled.
 * @param type - Which client implementation is connected.
 * @returns true only when the version is valid semver and is at least the
 *   minimum release for that client type.
 */
export function canCompress(
    clientVersion: string | null | undefined,
    type: "newt" | "olm"
): boolean {
    // Minimum client release per type that can handle compressed frames.
    const minimumVersion = type === "newt" ? "1.10.3" : "1.4.3";
    try {
        if (!clientVersion || !semver.valid(clientVersion)) {
            // No usable version string — fail closed and send uncompressed.
            return false;
        }
        return semver.gte(clientVersion, minimumVersion);
    } catch {
        // Defensive: a version-parsing failure must never break messaging.
        return false;
    }
}

View File

@@ -477,6 +477,7 @@ async function handleMessagesForSiteClients(
} }
if (isAdd) { if (isAdd) {
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake( await initPeerAddHandshake(
// this will kick off the add peer process for the client // this will kick off the add peer process for the client
client.clientId, client.clientId,
@@ -669,7 +670,11 @@ async function handleSubnetProxyTargetUpdates(
`Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}` `Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
); );
proxyJobs.push( proxyJobs.push(
addSubnetProxyTargets(newt.newtId, targetsToAdd) addSubnetProxyTargets(
newt.newtId,
targetsToAdd,
newt.version
)
); );
} }
@@ -705,7 +710,11 @@ async function handleSubnetProxyTargetUpdates(
`Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}` `Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
); );
proxyJobs.push( proxyJobs.push(
removeSubnetProxyTargets(newt.newtId, targetsToRemove) removeSubnetProxyTargets(
newt.newtId,
targetsToRemove,
newt.version
)
); );
} }
@@ -1080,6 +1089,7 @@ async function handleMessagesForClientSites(
continue; continue;
} }
// TODO: if we are in jit mode here should we really be sending this?
await initPeerAddHandshake( await initPeerAddHandshake(
// this will kick off the add peer process for the client // this will kick off the add peer process for the client
client.clientId, client.clientId,
@@ -1146,7 +1156,7 @@ async function handleMessagesForClientResources(
// Add subnet proxy targets for each site // Add subnet proxy targets for each site
for (const [siteId, resources] of addedBySite.entries()) { for (const [siteId, resources] of addedBySite.entries()) {
const [newt] = await trx const [newt] = await trx
.select({ newtId: newts.newtId }) .select({ newtId: newts.newtId, version: newts.version })
.from(newts) .from(newts)
.where(eq(newts.siteId, siteId)) .where(eq(newts.siteId, siteId))
.limit(1); .limit(1);
@@ -1168,7 +1178,13 @@ async function handleMessagesForClientResources(
]); ]);
if (targets.length > 0) { if (targets.length > 0) {
proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets)); proxyJobs.push(
addSubnetProxyTargets(
newt.newtId,
targets,
newt.version
)
);
} }
try { try {
@@ -1217,7 +1233,7 @@ async function handleMessagesForClientResources(
// Remove subnet proxy targets for each site // Remove subnet proxy targets for each site
for (const [siteId, resources] of removedBySite.entries()) { for (const [siteId, resources] of removedBySite.entries()) {
const [newt] = await trx const [newt] = await trx
.select({ newtId: newts.newtId }) .select({ newtId: newts.newtId, version: newts.version })
.from(newts) .from(newts)
.where(eq(newts.siteId, siteId)) .where(eq(newts.siteId, siteId))
.limit(1); .limit(1);
@@ -1240,7 +1256,11 @@ async function handleMessagesForClientResources(
if (targets.length > 0) { if (targets.length > 0) {
proxyJobs.push( proxyJobs.push(
removeSubnetProxyTargets(newt.newtId, targets) removeSubnetProxyTargets(
newt.newtId,
targets,
newt.version
)
); );
} }

View File

@@ -13,8 +13,12 @@
import { rateLimitService } from "#private/lib/rateLimit"; import { rateLimitService } from "#private/lib/rateLimit";
import { cleanup as wsCleanup } from "#private/routers/ws"; import { cleanup as wsCleanup } from "#private/routers/ws";
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
async function cleanup() { async function cleanup() {
await flushBandwidthToDb();
await flushSiteBandwidthToDb();
await rateLimitService.cleanup(); await rateLimitService.cleanup();
await wsCleanup(); await wsCleanup();

View File

@@ -29,7 +29,6 @@ import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors"; import createHttpError from "http-errors";
import logger from "@server/logger"; import logger from "@server/logger";
import { fromError } from "zod-validation-error"; import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { eq, or, and } from "drizzle-orm"; import { eq, or, and } from "drizzle-orm";
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource"; import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA"; import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
@@ -64,6 +63,7 @@ export type SignSshKeyResponse = {
sshUsername: string; sshUsername: string;
sshHost: string; sshHost: string;
resourceId: number; resourceId: number;
siteId: number;
keyId: string; keyId: string;
validPrincipals: string[]; validPrincipals: string[];
validAfter: string; validAfter: string;
@@ -453,6 +453,7 @@ export async function signSshKey(
sshUsername: usernameToUse, sshUsername: usernameToUse,
sshHost: sshHost, sshHost: sshHost,
resourceId: resource.siteResourceId, resourceId: resource.siteResourceId,
siteId: resource.siteId,
keyId: cert.keyId, keyId: cert.keyId,
validPrincipals: cert.validPrincipals, validPrincipals: cert.validPrincipals,
validAfter: cert.validAfter.toISOString(), validAfter: cert.validAfter.toISOString(),

View File

@@ -12,6 +12,7 @@
*/ */
import { Router, Request, Response } from "express"; import { Router, Request, Response } from "express";
import zlib from "zlib";
import { Server as HttpServer } from "http"; import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws"; import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net"; import { Socket } from "net";
@@ -24,7 +25,8 @@ import {
OlmSession, OlmSession,
RemoteExitNode, RemoteExitNode,
RemoteExitNodeSession, RemoteExitNodeSession,
remoteExitNodes remoteExitNodes,
sites
} from "@server/db"; } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { db } from "@server/db"; import { db } from "@server/db";
@@ -57,11 +59,13 @@ const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection
const processMessage = async ( const processMessage = async (
ws: AuthenticatedWebSocket, ws: AuthenticatedWebSocket,
data: Buffer, data: Buffer,
isBinary: boolean,
clientId: string, clientId: string,
clientType: ClientType clientType: ClientType
): Promise<void> => { ): Promise<void> => {
try { try {
const message: WSMessage = JSON.parse(data.toString()); const messageBuffer = isBinary ? zlib.gunzipSync(data) : data;
const message: WSMessage = JSON.parse(messageBuffer.toString());
// logger.debug( // logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}` // `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
@@ -76,7 +80,7 @@ const processMessage = async (
clientId, clientId,
message.type, // Pass message type for granular limiting message.type, // Pass message type for granular limiting
100, // max requests per window 100, // max requests per window
20, // max requests per message type per window 100, // max requests per message type per window
60 * 1000 // window in milliseconds 60 * 1000 // window in milliseconds
); );
if (rateLimitResult.isLimited) { if (rateLimitResult.isLimited) {
@@ -163,8 +167,16 @@ const processPendingMessages = async (
); );
const jobs = []; const jobs = [];
for (const messageData of ws.pendingMessages) { for (const pending of ws.pendingMessages) {
jobs.push(processMessage(ws, messageData, clientId, clientType)); jobs.push(
processMessage(
ws,
pending.data,
pending.isBinary,
clientId,
clientType
)
);
} }
await Promise.all(jobs); await Promise.all(jobs);
@@ -325,7 +337,9 @@ const addClient = async (
// Check Redis first if enabled // Check Redis first if enabled
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
const redisVersion = await redisManager.get(getConfigVersionKey(clientId)); const redisVersion = await redisManager.get(
getConfigVersionKey(clientId)
);
if (redisVersion !== null) { if (redisVersion !== null) {
configVersion = parseInt(redisVersion, 10); configVersion = parseInt(redisVersion, 10);
// Sync to local cache // Sync to local cache
@@ -337,7 +351,10 @@ const addClient = async (
} else { } else {
// Use local cache version and sync to Redis // Use local cache version and sync to Redis
configVersion = clientConfigVersions.get(clientId) || 0; configVersion = clientConfigVersions.get(clientId) || 0;
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString()); await redisManager.set(
getConfigVersionKey(clientId),
configVersion.toString()
);
} }
} catch (error) { } catch (error) {
logger.error("Failed to get/set config version in Redis:", error); logger.error("Failed to get/set config version in Redis:", error);
@@ -432,7 +449,9 @@ const removeClient = async (
}; };
// Helper to get the current config version for a client // Helper to get the current config version for a client
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => { const getClientConfigVersion = async (
clientId: string
): Promise<number | undefined> => {
// Try Redis first if available // Try Redis first if available
if (redisManager.isRedisEnabled()) { if (redisManager.isRedisEnabled()) {
try { try {
@@ -502,11 +521,26 @@ const sendToClientLocal = async (
}; };
const messageString = JSON.stringify(messageWithVersion); const messageString = JSON.stringify(messageWithVersion);
if (options.compress) {
logger.debug(
`Message size before compression: ${messageString.length} bytes`
);
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
logger.debug(
`Message size after compression: ${compressed.length} bytes`
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(messageString);
} }
}); });
}
return true; return true;
}; };
@@ -532,6 +566,16 @@ const broadcastToAllExceptLocal = async (
configVersion configVersion
}; };
if (options.compress) {
const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion)); client.send(JSON.stringify(messageWithVersion));
@@ -539,6 +583,7 @@ const broadcastToAllExceptLocal = async (
}); });
} }
} }
}
}; };
// Cross-node message sending (via Redis) // Cross-node message sending (via Redis)
@@ -762,7 +807,7 @@ const setupConnection = async (
} }
// Set up message handler FIRST to prevent race condition // Set up message handler FIRST to prevent race condition
ws.on("message", async (data) => { ws.on("message", async (data, isBinary) => {
if (!ws.isFullyConnected) { if (!ws.isFullyConnected) {
// Queue message for later processing with limits // Queue message for later processing with limits
ws.pendingMessages = ws.pendingMessages || []; ws.pendingMessages = ws.pendingMessages || [];
@@ -777,11 +822,17 @@ const setupConnection = async (
logger.debug( logger.debug(
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)` `Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
); );
ws.pendingMessages.push(data as Buffer); ws.pendingMessages.push({ data: data as Buffer, isBinary });
return; return;
} }
await processMessage(ws, data as Buffer, clientId, clientType); await processMessage(
ws,
data as Buffer,
isBinary,
clientId,
clientType
);
}); });
// Set up other event handlers before async operations // Set up other event handlers before async operations
@@ -796,6 +847,31 @@ const setupConnection = async (
); );
}); });
// Handle WebSocket protocol-level pings from older newt clients that do
// not send application-level "newt/ping" messages. Update the site's
// online state and lastPing timestamp so the offline checker treats them
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
if (!newtClient.siteId) return;
try {
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
});
}
ws.on("error", (error: Error) => { ws.on("error", (error: Error) => {
logger.error( logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, `WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,

View File

@@ -1,51 +1,38 @@
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { db, olms, Transaction } from "@server/db"; import { db, olms, Transaction } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks";
import { Alias, SubnetProxyTarget } from "@server/lib/ip"; import { Alias, SubnetProxyTarget } from "@server/lib/ip";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
const BATCH_SIZE = 50; export async function addTargets(
const BATCH_DELAY_MS = 50; newtId: string,
targets: SubnetProxyTarget[],
function sleep(ms: number): Promise<void> { version?: string | null
return new Promise((resolve) => setTimeout(resolve, ms)); ) {
} await sendToClient(
newtId,
function chunkArray<T>(array: T[], size: number): T[][] { {
const chunks: T[][] = [];
for (let i = 0; i < array.length; i += size) {
chunks.push(array.slice(i, i + size));
}
return chunks;
}
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
const batches = chunkArray(targets, BATCH_SIZE);
for (let i = 0; i < batches.length; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(newtId, {
type: `newt/wg/targets/add`, type: `newt/wg/targets/add`,
data: batches[i] data: targets
}, { incrementConfigVersion: true }); },
} { incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
} }
export async function removeTargets( export async function removeTargets(
newtId: string, newtId: string,
targets: SubnetProxyTarget[] targets: SubnetProxyTarget[],
version?: string | null
) { ) {
const batches = chunkArray(targets, BATCH_SIZE); await sendToClient(
for (let i = 0; i < batches.length; i++) { newtId,
if (i > 0) { {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(newtId, {
type: `newt/wg/targets/remove`, type: `newt/wg/targets/remove`,
data: batches[i] data: targets
},{ incrementConfigVersion: true }); },
} { incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
} }
export async function updateTargets( export async function updateTargets(
@@ -53,34 +40,31 @@ export async function updateTargets(
targets: { targets: {
oldTargets: SubnetProxyTarget[]; oldTargets: SubnetProxyTarget[];
newTargets: SubnetProxyTarget[]; newTargets: SubnetProxyTarget[];
} },
version?: string | null
) { ) {
const oldBatches = chunkArray(targets.oldTargets, BATCH_SIZE); await sendToClient(
const newBatches = chunkArray(targets.newTargets, BATCH_SIZE); newtId,
const maxBatches = Math.max(oldBatches.length, newBatches.length); {
for (let i = 0; i < maxBatches; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(newtId, {
type: `newt/wg/targets/update`, type: `newt/wg/targets/update`,
data: { data: {
oldTargets: oldBatches[i] || [], oldTargets: targets.oldTargets,
newTargets: newBatches[i] || [] newTargets: targets.newTargets
} }
}, { incrementConfigVersion: true }).catch((error) => { },
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
}
export async function addPeerData( export async function addPeerData(
clientId: number, clientId: number,
siteId: number, siteId: number,
remoteSubnets: string[], remoteSubnets: string[],
aliases: Alias[], aliases: Alias[],
olmId?: string olmId?: string,
version?: string | null
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -92,16 +76,21 @@ export async function addPeerData(
return; // ignore this because an olm might not be associated with the client anymore return; // ignore this because an olm might not be associated with the client anymore
} }
olmId = olm.olmId; olmId = olm.olmId;
version = olm.version;
} }
await sendToClient(olmId, { await sendToClient(
olmId,
{
type: `olm/wg/peer/data/add`, type: `olm/wg/peer/data/add`,
data: { data: {
siteId: siteId, siteId: siteId,
remoteSubnets: remoteSubnets, remoteSubnets: remoteSubnets,
aliases: aliases aliases: aliases
} }
}, { incrementConfigVersion: true }).catch((error) => { },
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -111,7 +100,8 @@ export async function removePeerData(
siteId: number, siteId: number,
remoteSubnets: string[], remoteSubnets: string[],
aliases: Alias[], aliases: Alias[],
olmId?: string olmId?: string,
version?: string | null
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -123,16 +113,21 @@ export async function removePeerData(
return; return;
} }
olmId = olm.olmId; olmId = olm.olmId;
version = olm.version;
} }
await sendToClient(olmId, { await sendToClient(
olmId,
{
type: `olm/wg/peer/data/remove`, type: `olm/wg/peer/data/remove`,
data: { data: {
siteId: siteId, siteId: siteId,
remoteSubnets: remoteSubnets, remoteSubnets: remoteSubnets,
aliases: aliases aliases: aliases
} }
}, { incrementConfigVersion: true }).catch((error) => { },
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }
@@ -152,7 +147,8 @@ export async function updatePeerData(
newAliases: Alias[]; newAliases: Alias[];
} }
| undefined, | undefined,
olmId?: string olmId?: string,
version?: string | null
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -164,16 +160,21 @@ export async function updatePeerData(
return; return;
} }
olmId = olm.olmId; olmId = olm.olmId;
version = olm.version;
} }
await sendToClient(olmId, { await sendToClient(
olmId,
{
type: `olm/wg/peer/data/update`, type: `olm/wg/peer/data/update`,
data: { data: {
siteId: siteId, siteId: siteId,
...remoteSubnets, ...remoteSubnets,
...aliases ...aliases
} }
}, { incrementConfigVersion: true }).catch((error) => { },
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
} }

View File

@@ -1,5 +1,5 @@
import { Request, Response, NextFunction } from "express"; import { Request, Response, NextFunction } from "express";
import { eq, and, lt, inArray, sql } from "drizzle-orm"; import { eq, sql } from "drizzle-orm";
import { sites } from "@server/db"; import { sites } from "@server/db";
import { db } from "@server/db"; import { db } from "@server/db";
import logger from "@server/logger"; import logger from "@server/logger";
@@ -11,19 +11,31 @@ import { FeatureId } from "@server/lib/billing/features";
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes"; import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
import { build } from "@server/build"; import { build } from "@server/build";
// Track sites that are already offline to avoid unnecessary queries
const offlineSites = new Set<string>();
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
interface PeerBandwidth { interface PeerBandwidth {
publicKey: string; publicKey: string;
bytesIn: number; bytesIn: number;
bytesOut: number; bytesOut: number;
} }
interface AccumulatorEntry {
bytesIn: number;
bytesOut: number;
/** Present when the update came through a remote exit node. */
exitNodeId?: number;
/** Whether to record egress usage for billing purposes. */
calcUsage: boolean;
}
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 30_000; // 30 seconds
// In-memory accumulator: publicKey -> AccumulatorEntry
let accumulator = new Map<string, AccumulatorEntry>();
/** /**
* Check if an error is a deadlock error * Check if an error is a deadlock error
*/ */
@@ -63,6 +75,220 @@ async function withDeadlockRetry<T>(
} }
} }
/**
 * Flush all accumulated site bandwidth data to the database.
 *
 * Swaps out the accumulator before writing so that any bandwidth messages
 * received during the flush are captured in the new accumulator rather than
 * being lost or causing contention. Entries that fail to write are re-queued
 * back into the accumulator so they will be retried on the next flush.
 *
 * This function is exported so that the application's graceful-shutdown
 * cleanup handler can call it before the process exits.
 */
export async function flushSiteBandwidthToDb(): Promise<void> {
    if (accumulator.size === 0) {
        return;
    }

    // Atomically swap out the accumulator so new data keeps flowing in
    // while we write the snapshot to the database.
    const snapshot = accumulator;
    accumulator = new Map<string, AccumulatorEntry>();

    const currentTime = new Date().toISOString();

    // Sort by publicKey for consistent lock ordering across concurrent
    // writers — deadlock-prevention strategy.
    const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
        a.localeCompare(b)
    );

    logger.debug(
        `Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
    );

    // Aggregate billing usage by org, collected during the DB update loop.
    const orgUsageMap = new Map<string, number>();

    for (const [
        publicKey,
        { bytesIn, bytesOut, exitNodeId, calcUsage }
    ] of sortedEntries) {
        try {
            const updatedSite = await withDeadlockRetry(async () => {
                // NOTE(review): bytesIn is added to megabytesOut and
                // bytesOut to megabytesIn — presumably a deliberate
                // perspective flip between the peer report and the site
                // columns, and the values are presumably already in MB
                // despite the "bytes" names. Confirm against the sender.
                const [result] = await db
                    .update(sites)
                    .set({
                        megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
                        megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
                        lastBandwidthUpdate: currentTime
                    })
                    .where(eq(sites.pubKey, publicKey))
                    .returning({
                        orgId: sites.orgId,
                        siteId: sites.siteId
                    });
                return result;
            }, `flush bandwidth for site ${publicKey}`);

            if (updatedSite) {
                if (exitNodeId) {
                    const notAllowed = await checkExitNodeOrg(
                        exitNodeId,
                        updatedSite.orgId
                    );
                    if (notAllowed) {
                        logger.warn(
                            `Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
                        );
                        // Skip usage tracking for this site but continue
                        // processing the rest.
                        continue;
                    }
                }

                if (calcUsage) {
                    const totalBandwidth = bytesIn + bytesOut;
                    const current = orgUsageMap.get(updatedSite.orgId) ?? 0;
                    orgUsageMap.set(
                        updatedSite.orgId,
                        current + totalBandwidth
                    );
                }
            }
        } catch (error) {
            logger.error(
                `Failed to flush bandwidth for site ${publicKey}:`,
                error
            );
            // Re-queue the failed entry so it is retried on the next flush
            // rather than silently dropped.
            const existing = accumulator.get(publicKey);
            if (existing) {
                existing.bytesIn += bytesIn;
                existing.bytesOut += bytesOut;
                // Fix: previously the failed entry's exitNodeId and
                // calcUsage were dropped when merging into a newer entry,
                // so the retry could skip the exit-node check and lose
                // billable usage. Preserve them here.
                if (existing.exitNodeId === undefined) {
                    existing.exitNodeId = exitNodeId;
                }
                if (calcUsage) {
                    existing.calcUsage = true;
                }
            } else {
                accumulator.set(publicKey, {
                    bytesIn,
                    bytesOut,
                    exitNodeId,
                    calcUsage
                });
            }
        }
    }

    // Process billing usage updates outside the site-update loop to keep
    // lock scope small and concerns separated.
    if (orgUsageMap.size > 0) {
        // Sort org IDs for consistent lock ordering.
        const sortedOrgIds = [...orgUsageMap.keys()].sort();

        for (const orgId of sortedOrgIds) {
            try {
                const totalBandwidth = orgUsageMap.get(orgId)!;
                const bandwidthUsage = await usageService.add(
                    orgId,
                    FeatureId.EGRESS_DATA_MB,
                    totalBandwidth
                );

                if (bandwidthUsage) {
                    // Fire-and-forget — don't block the flush on limit checking.
                    usageService
                        .checkLimitSet(
                            orgId,
                            FeatureId.EGRESS_DATA_MB,
                            bandwidthUsage
                        )
                        .catch((error: any) => {
                            logger.error(
                                `Error checking bandwidth limits for org ${orgId}:`,
                                error
                            );
                        });
                }
            } catch (error) {
                logger.error(
                    `Error processing usage for org ${orgId}:`,
                    error
                );
                // Continue with other orgs.
            }
        }
    }
}
// ---------------------------------------------------------------------------
// Periodic flush timer
// ---------------------------------------------------------------------------

// Kick off a flush every FLUSH_INTERVAL_MS; any unexpected rejection is
// logged so the timer keeps running.
const flushTimer = setInterval(() => {
    flushSiteBandwidthToDb().catch((error) => {
        logger.error(
            "Unexpected error during periodic site bandwidth flush:",
            error
        );
    });
}, FLUSH_INTERVAL_MS);

// Allow the process to exit normally even while the timer is pending.
// The graceful-shutdown path (see server/cleanup.ts) will call
// flushSiteBandwidthToDb() explicitly before process.exit(), so no data
// is lost.
flushTimer.unref();
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
/**
 * Accumulate bandwidth data reported by a gerbil or remote exit node.
 *
 * Peers that transferred no data in either direction (bytesIn and bytesOut
 * both <= 0) are skipped, so the periodic flush only writes rows that have
 * genuinely changed.
 *
 * The function does no I/O itself — it only mutates the in-memory
 * accumulator — so the HTTP handler can respond immediately.
 *
 * @param bandwidthData - Per-peer byte counters keyed by WireGuard public key.
 * @param calcUsageAndLimits - Whether egress usage should be billed for
 *   these peers; once set for a peer it stays set for the flush window.
 * @param exitNodeId - Optional reporting exit node; the most recent value
 *   wins for each peer.
 */
export async function updateSiteBandwidth(
    bandwidthData: PeerBandwidth[],
    calcUsageAndLimits: boolean,
    exitNodeId?: number
): Promise<void> {
    for (const peer of bandwidthData) {
        // Idle peers would be a no-op write anyway — drop them here.
        if (peer.bytesIn <= 0 && peer.bytesOut <= 0) {
            continue;
        }

        const entry = accumulator.get(peer.publicKey);
        if (entry === undefined) {
            accumulator.set(peer.publicKey, {
                bytesIn: peer.bytesIn,
                bytesOut: peer.bytesOut,
                exitNodeId,
                calcUsage: calcUsageAndLimits
            });
            continue;
        }

        entry.bytesIn += peer.bytesIn;
        entry.bytesOut += peer.bytesOut;
        // Retain the most-recent exitNodeId for this peer.
        if (exitNodeId !== undefined) {
            entry.exitNodeId = exitNodeId;
        }
        // Once usage calculation has been requested, keep it sticky for
        // the lifetime of this flush window.
        if (calcUsageAndLimits) {
            entry.calcUsage = true;
        }
    }
}
// ---------------------------------------------------------------------------
// HTTP handler
// ---------------------------------------------------------------------------
export const receiveBandwidth = async ( export const receiveBandwidth = async (
req: Request, req: Request,
res: Response, res: Response,
@@ -75,7 +301,9 @@ export const receiveBandwidth = async (
throw new Error("Invalid bandwidth data"); throw new Error("Invalid bandwidth data");
} }
await updateSiteBandwidth(bandwidthData, build == "saas"); // we are checking the usage on saas only // Accumulate in memory; the periodic timer (and the shutdown hook)
// will write to the database.
await updateSiteBandwidth(bandwidthData, build == "saas");
return response(res, { return response(res, {
data: {}, data: {},
@@ -94,201 +322,3 @@ export const receiveBandwidth = async (
); );
} }
}; };
/**
 * Apply peer bandwidth reports directly to the `sites` table and, optionally,
 * to per-org usage metering.
 *
 * Three phases, each isolated so one failure does not abort the others:
 *  1. For peers with bytesIn > 0: atomically increment the site's counters,
 *     mark it online, and collect per-org totals.
 *  2. If calcUsageAndLimits: record the per-org totals via usageService and
 *     kick off (fire-and-forget) limit checks.
 *  3. For peers with bytesIn === 0: mark sites offline when their last
 *     data-carrying update is older than one minute.
 *
 * NOTE(review): peer.bytesIn is added to megabytesOut and peer.bytesOut to
 * megabytesIn — presumably the columns are named from the site's perspective;
 * confirm. Also the columns say "megabytes" but raw peer byte counts are
 * added — verify units with the reporting side.
 *
 * @param bandwidthData     per-peer byte counts keyed by WireGuard publicKey
 * @param calcUsageAndLimits when true, also meter org usage and check limits
 * @param exitNodeId        optional reporting exit node; when set, sites whose
 *                          org is not allowed on that node are skipped/warned
 */
export async function updateSiteBandwidth(
    bandwidthData: PeerBandwidth[],
    calcUsageAndLimits: boolean,
    exitNodeId?: number
) {
    const currentTime = new Date();
    const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago

    // Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
    // This is critical for preventing deadlocks when multiple instances update the same sites
    const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
        a.publicKey.localeCompare(b.publicKey)
    );

    // First, handle sites that are actively reporting bandwidth
    const activePeers = sortedBandwidthData.filter((peer) => peer.bytesIn > 0);

    // Aggregate usage data by organization (collected outside transaction)
    const orgUsageMap = new Map<string, number>();

    if (activePeers.length > 0) {
        // Remove any active peers from offline tracking since they're sending data
        activePeers.forEach((peer) => offlineSites.delete(peer.publicKey));

        // Update each active site individually with retry logic
        // This reduces transaction scope and allows retries per-site
        for (const peer of activePeers) {
            try {
                const updatedSite = await withDeadlockRetry(async () => {
                    // Atomic SQL increments avoid a SELECT-then-UPDATE race.
                    const [result] = await db
                        .update(sites)
                        .set({
                            megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
                            megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
                            lastBandwidthUpdate: currentTime.toISOString(),
                            online: true
                        })
                        .where(eq(sites.pubKey, peer.publicKey))
                        .returning({
                            online: sites.online,
                            orgId: sites.orgId,
                            siteId: sites.siteId,
                            lastBandwidthUpdate: sites.lastBandwidthUpdate
                        });
                    return result;
                }, `update active site ${peer.publicKey}`);

                // updatedSite is undefined when no site row matched the pubKey.
                if (updatedSite) {
                    if (exitNodeId) {
                        const notAllowed = await checkExitNodeOrg(
                            exitNodeId,
                            updatedSite.orgId
                        );
                        if (notAllowed) {
                            logger.warn(
                                `Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
                            );
                            // Skip this site but continue processing others
                            continue;
                        }
                    }

                    // Aggregate bandwidth usage for the org
                    const totalBandwidth = peer.bytesIn + peer.bytesOut;
                    const currentOrgUsage =
                        orgUsageMap.get(updatedSite.orgId) || 0;
                    orgUsageMap.set(
                        updatedSite.orgId,
                        currentOrgUsage + totalBandwidth
                    );
                }
            } catch (error) {
                logger.error(
                    `Failed to update bandwidth for site ${peer.publicKey}:`,
                    error
                );
                // Continue with other sites
            }
        }
    }

    // Process usage updates outside of site update transactions
    // This separates the concerns and reduces lock contention
    if (calcUsageAndLimits && orgUsageMap.size > 0) {
        // Sort org IDs to ensure consistent lock ordering
        const allOrgIds = [...new Set([...orgUsageMap.keys()])].sort();

        for (const orgId of allOrgIds) {
            try {
                // Process bandwidth usage for this org
                const totalBandwidth = orgUsageMap.get(orgId);
                if (totalBandwidth) {
                    const bandwidthUsage = await usageService.add(
                        orgId,
                        FeatureId.EGRESS_DATA_MB,
                        totalBandwidth
                    );

                    if (bandwidthUsage) {
                        // Fire and forget - don't block on limit checking
                        usageService
                            .checkLimitSet(
                                orgId,
                                FeatureId.EGRESS_DATA_MB,
                                bandwidthUsage
                            )
                            .catch((error: any) => {
                                logger.error(
                                    `Error checking bandwidth limits for org ${orgId}:`,
                                    error
                                );
                            });
                    }
                }
            } catch (error) {
                logger.error(`Error processing usage for org ${orgId}:`, error);
                // Continue with other orgs
            }
        }
    }

    // Handle sites that reported zero bandwidth but need online status updated
    // (offlineSites caches pubKeys already marked offline so we don't re-query
    // them every report).
    const zeroBandwidthPeers = sortedBandwidthData.filter(
        (peer) => peer.bytesIn === 0 && !offlineSites.has(peer.publicKey)
    );

    if (zeroBandwidthPeers.length > 0) {
        // Fetch all zero bandwidth sites in one query
        const zeroBandwidthSites = await db
            .select()
            .from(sites)
            .where(
                inArray(
                    sites.pubKey,
                    zeroBandwidthPeers.map((p) => p.publicKey)
                )
            );

        // Sort by siteId to ensure consistent lock ordering
        const sortedZeroBandwidthSites = zeroBandwidthSites.sort(
            (a, b) => a.siteId - b.siteId
        );

        for (const site of sortedZeroBandwidthSites) {
            let newOnlineStatus = site.online;

            // Check if site should go offline based on last bandwidth update WITH DATA
            if (site.lastBandwidthUpdate) {
                const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
                if (lastUpdateWithData < oneMinuteAgo) {
                    newOnlineStatus = false;
                }
            } else {
                // No previous data update recorded, set to offline
                newOnlineStatus = false;
            }

            // Only update online status if it changed
            if (site.online !== newOnlineStatus) {
                try {
                    const updatedSite = await withDeadlockRetry(async () => {
                        const [result] = await db
                            .update(sites)
                            .set({
                                online: newOnlineStatus
                            })
                            .where(eq(sites.siteId, site.siteId))
                            .returning();
                        return result;
                    }, `update offline status for site ${site.siteId}`);

                    // NOTE(review): unlike the active-peer path, a disallowed
                    // exit node here is only warned about, not skipped —
                    // confirm that asymmetry is intentional.
                    if (updatedSite && exitNodeId) {
                        const notAllowed = await checkExitNodeOrg(
                            exitNodeId,
                            updatedSite.orgId
                        );
                        if (notAllowed) {
                            logger.warn(
                                `Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
                            );
                        }
                    }

                    // If site went offline, add it to our tracking set
                    if (!newOnlineStatus && site.pubKey) {
                        offlineSites.add(site.pubKey);
                    }
                } catch (error) {
                    logger.error(
                        `Failed to update offline status for site ${site.siteId}:`,
                        error
                    );
                    // Continue with other sites
                }
            }
        }
    }
}

View File

@@ -1,4 +1,15 @@
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db"; import {
clients,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
ExitNode,
resources,
Site,
siteResources,
targetHealthCheck,
targets
} from "@server/db";
import logger from "@server/logger"; import logger from "@server/logger";
import { initPeerAddHandshake, updatePeer } from "../olm/peers"; import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { eq, and } from "drizzle-orm"; import { eq, and } from "drizzle-orm";
@@ -69,6 +80,7 @@ export async function buildClientConfigurationForNewtClient(
// ) // )
// ); // );
if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
// update the peer info on the olm // update the peer info on the olm
// if the peer has not been added yet this will be a no-op // if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, { await updatePeer(client.clients.clientId, {
@@ -103,6 +115,7 @@ export async function buildClientConfigurationForNewtClient(
} }
} }
); );
}
return { return {
publicKey: client.clients.pubKey!, publicKey: client.clients.pubKey!,
@@ -230,9 +243,9 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
!target.hcInterval || !target.hcInterval ||
!target.hcMethod !target.hcMethod
) { ) {
logger.debug( // logger.debug(
`Skipping adding target health check ${target.targetId} due to missing health check fields` // `Skipping adding target health check ${target.targetId} due to missing health check fields`
); // );
return null; // Skip targets with missing health check fields return null; // Skip targets with missing health check fields
} }

View File

@@ -6,6 +6,7 @@ import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { sendToExitNode } from "#dynamic/lib/exitNodes"; import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { buildClientConfigurationForNewtClient } from "./buildConfiguration"; import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
const inputSchema = z.object({ const inputSchema = z.object({
publicKey: z.string(), publicKey: z.string(),
@@ -135,6 +136,9 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
targets targets
} }
}, },
options: {
compress: canCompress(newt.version, "newt")
},
broadcast: false, broadcast: false,
excludeSender: false excludeSender: false
}; };

View File

@@ -1,105 +1,107 @@
import { db, sites } from "@server/db"; import { db, newts, sites } from "@server/db";
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws"; import { hasActiveConnections, getClientConfigVersion } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws"; import { MessageHandler } from "@server/routers/ws";
import { clients, Newt } from "@server/db"; import { Newt } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm"; import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger"; import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { sendNewtSyncMessage } from "./sync"; import { sendNewtSyncMessage } from "./sync";
// Track if the offline checker interval is running // Track if the offline checker interval is running
// let offlineCheckerInterval: NodeJS.Timeout | null = null; let offlineCheckerInterval: NodeJS.Timeout | null = null;
// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/** /**
* Starts the background interval that checks for clients that haven't pinged recently * Starts the background interval that checks for newt sites that haven't
* and marks them as offline * pinged recently and marks them as offline. For backward compatibility,
* a site is only marked offline when there is no active WebSocket connection
* either — so older newt versions that don't send pings but remain connected
* continue to be treated as online.
*/ */
// export const startNewtOfflineChecker = (): void => { export const startNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) { if (offlineCheckerInterval) {
// return; // Already running return; // Already running
// } }
// offlineCheckerInterval = setInterval(async () => { offlineCheckerInterval = setInterval(async () => {
// try { try {
// const twoMinutesAgo = Math.floor( const twoMinutesAgo = Math.floor(
// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000 (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
// ); );
// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING // Find all online newt-type sites that haven't pinged recently
// (or have never pinged at all). Join newts to obtain the newtId
// needed for the WebSocket connection check.
const staleSites = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
lastPing: sites.lastPing
})
.from(sites)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
eq(sites.online, true),
eq(sites.type, "newt"),
or(
lt(sites.lastPing, twoMinutesAgo),
isNull(sites.lastPing)
)
)
);
// // Find clients that haven't pinged in the last 2 minutes and mark them as offline for (const staleSite of staleSites) {
// const offlineClients = await db // Backward-compatibility check: if the newt still has an
// .update(clients) // active WebSocket connection (older clients that don't send
// .set({ online: false }) // pings), keep the site online.
// .where( const isConnected = await hasActiveConnections(staleSite.newtId);
// and( if (isConnected) {
// eq(clients.online, true), logger.debug(
// or( `Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
// lt(clients.lastPing, twoMinutesAgo), );
// isNull(clients.lastPing) continue;
// ) }
// )
// )
// .returning();
// for (const offlineClient of offlineClients) { logger.info(
// logger.info( `Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
// `Kicking offline newt client ${offlineClient.clientId} due to inactivity` );
// );
// if (!offlineClient.newtId) { await db
// logger.warn( .update(sites)
// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect` .set({ online: false })
// ); .where(eq(sites.siteId, staleSite.siteId));
// continue; }
// } } catch (error) {
logger.error("Error in newt offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
// // Send a disconnect message to the client if connected logger.debug("Started newt offline checker interval");
// try { };
// await sendTerminateClient(
// offlineClient.clientId,
// offlineClient.newtId
// ); // terminate first
// // wait a moment to ensure the message is sent
// await new Promise((resolve) => setTimeout(resolve, 1000));
// await disconnectClient(offlineClient.newtId);
// } catch (error) {
// logger.error(
// `Error sending disconnect to offline newt ${offlineClient.clientId}`,
// { error }
// );
// }
// }
// } catch (error) {
// logger.error("Error in offline checker interval", { error });
// }
// }, OFFLINE_CHECK_INTERVAL);
// logger.debug("Started offline checker interval");
// };
/** /**
* Stops the background interval that checks for offline clients * Stops the background interval that checks for offline newt sites.
*/ */
// export const stopNewtOfflineChecker = (): void => { export const stopNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) { if (offlineCheckerInterval) {
// clearInterval(offlineCheckerInterval); clearInterval(offlineCheckerInterval);
// offlineCheckerInterval = null; offlineCheckerInterval = null;
// logger.info("Stopped offline checker interval"); logger.info("Stopped newt offline checker interval");
// } }
// }; };
/** /**
* Handles ping messages from clients and responds with pong * Handles ping messages from newt clients.
*
* On each ping:
* - Marks the associated site as online.
* - Records the current timestamp as the newt's last-ping time.
* - Triggers a config sync if the newt is running an outdated config version.
* - Responds with a pong message.
*/ */
export const handleNewtPingMessage: MessageHandler = async (context) => { export const handleNewtPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context; const { message, client: c } = context;
const newt = c as Newt; const newt = c as Newt;
if (!newt) { if (!newt) {
@@ -112,15 +114,31 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
return; return;
} }
// get the version try {
// Mark the site as online and record the ping timestamp.
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newt.siteId));
} catch (error) {
logger.error("Error updating online state on newt ping", { error });
}
// Check config version and sync if stale.
const configVersion = await getClientConfigVersion(newt.newtId); const configVersion = await getClientConfigVersion(newt.newtId);
if (message.configVersion && configVersion != null && configVersion != message.configVersion) { if (
message.configVersion != null &&
configVersion != null &&
configVersion !== message.configVersion
) {
logger.warn( logger.warn(
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})` `Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
); );
// get the site
const [site] = await db const [site] = await db
.select() .select()
.from(sites) .from(sites)
@@ -137,19 +155,6 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
await sendNewtSyncMessage(newt, site); await sendNewtSyncMessage(newt, site);
} }
// try {
// // Update the client's last ping timestamp
// await db
// .update(clients)
// .set({
// lastPing: Math.floor(Date.now() / 1000),
// online: true
// })
// .where(eq(clients.clientId, newt.clientId));
// } catch (error) {
// logger.error("Error handling ping message", { error });
// }
return { return {
message: { message: {
type: "pong", type: "pong",

View File

@@ -5,9 +5,7 @@ import { eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers"; import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger"; import logger from "@server/logger";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { import { findNextAvailableCidr } from "@server/lib/ip";
findNextAvailableCidr,
} from "@server/lib/ip";
import { import {
selectBestExitNode, selectBestExitNode,
verifyExitNodeOrgAccess verifyExitNodeOrgAccess
@@ -15,6 +13,7 @@ import {
import { fetchContainers } from "./dockerSocket"; import { fetchContainers } from "./dockerSocket";
import { lockManager } from "#dynamic/lib/lock"; import { lockManager } from "#dynamic/lib/lock";
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration"; import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
export type ExitNodePingResult = { export type ExitNodePingResult = {
exitNodeId: number; exitNodeId: number;
@@ -215,6 +214,9 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
healthCheckTargets: validHealthCheckTargets healthCheckTargets: validHealthCheckTargets
} }
}, },
options: {
compress: canCompress(newt.version, "newt")
},
broadcast: false, // Send to all clients broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast excludeSender: false // Include sender in broadcast
}; };

View File

@@ -10,10 +10,21 @@ interface PeerBandwidth {
bytesOut: number; bytesOut: number;
} }
// Running byte totals for a single client (keyed by its WireGuard public
// key), accumulated in memory between periodic database flushes.
interface BandwidthAccumulator {
    bytesIn: number;
    bytesOut: number;
}
// Retry configuration for deadlock handling // Retry configuration for deadlock handling
const MAX_RETRIES = 3; const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50; const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 120_000; // 120 seconds
// In-memory accumulator: publicKey -> { bytesIn, bytesOut }
let accumulator = new Map<string, BandwidthAccumulator>();
/** /**
* Check if an error is a deadlock error * Check if an error is a deadlock error
*/ */
@@ -53,6 +64,90 @@ async function withDeadlockRetry<T>(
} }
} }
/**
* Flush all accumulated bandwidth data to the database.
*
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued
* back into the accumulator so they will be retried on the next flush.
*
* This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits.
*/
export async function flushBandwidthToDb(): Promise<void> {
if (accumulator.size === 0) {
return;
}
// Atomically swap out the accumulator so new data keeps flowing in
// while we write the snapshot to the database.
const snapshot = accumulator;
accumulator = new Map<string, BandwidthAccumulator>();
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers — this is the same deadlock-prevention strategy used in the
// original per-message implementation.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
);
logger.debug(
`Flushing accumulated bandwidth data for ${sortedEntries.length} client(s) to the database`
);
for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) {
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid the SELECT-then-UPDATE
// anti-pattern and the races it would introduce.
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data
// sent to client) and bytesOut from peer goes to
// megabytesIn (data received from client).
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `flush bandwidth for client ${publicKey}`);
} catch (error) {
logger.error(
`Failed to flush bandwidth for client ${publicKey}:`,
error
);
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, { bytesIn, bytesOut });
}
}
}
}
// Periodic flush of the in-memory client bandwidth accumulator. Errors are
// logged; the next interval tick simply tries again.
const flushTimer = setInterval(() => {
    flushBandwidthToDb().catch((error) => {
        logger.error("Unexpected error during periodic bandwidth flush:", error);
    });
}, FLUSH_INTERVAL_MS);

// unref() means this timer will not keep the Node.js event loop alive on its
// own — the process can still exit normally when there is no other work left.
// The graceful-shutdown path (see server/cleanup.ts) calls
// flushBandwidthToDb() explicitly before process.exit(), so no data is lost.
flushTimer.unref();
export const handleReceiveBandwidthMessage: MessageHandler = async ( export const handleReceiveBandwidthMessage: MessageHandler = async (
context context
) => { ) => {
@@ -69,40 +164,21 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
throw new Error("Invalid bandwidth data"); throw new Error("Invalid bandwidth data");
} }
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances // Accumulate the incoming data in memory; the periodic timer (and the
// This is critical for preventing deadlocks when multiple instances update the same clients // shutdown hook) will take care of writing it to the database.
const sortedBandwidthData = [...bandwidthData].sort((a, b) => for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
a.publicKey.localeCompare(b.publicKey) // Skip peers that haven't transferred any data — writing zeros to the
); // database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;
}
const currentTime = new Date().toISOString(); const existing = accumulator.get(publicKey);
if (existing) {
// Update each client individually with retry logic existing.bytesIn += bytesIn;
// This reduces transaction scope and allows retries per-client existing.bytesOut += bytesOut;
for (const peer of sortedBandwidthData) { } else {
const { publicKey, bytesIn, bytesOut } = peer; accumulator.set(publicKey, { bytesIn, bytesOut });
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid SELECT then UPDATE pattern
// This eliminates the need to read the current value first
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data sent to client)
// and bytesOut from peer goes to megabytesIn (data received from client)
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `update client bandwidth ${publicKey}`);
} catch (error) {
logger.error(
`Failed to update bandwidth for client ${publicKey}:`,
error
);
// Continue with other clients even if one fails
} }
} }
}; };

View File

@@ -6,6 +6,7 @@ import {
buildClientConfigurationForNewtClient, buildClientConfigurationForNewtClient,
buildTargetConfigurationForNewtClient buildTargetConfigurationForNewtClient
} from "./buildConfiguration"; } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function sendNewtSyncMessage(newt: Newt, site: Site) { export async function sendNewtSyncMessage(newt: Newt, site: Site) {
const { tcpTargets, udpTargets, validHealthCheckTargets } = const { tcpTargets, udpTargets, validHealthCheckTargets } =
@@ -24,7 +25,9 @@ export async function sendNewtSyncMessage(newt: Newt, site: Site) {
exitNode exitNode
); );
await sendToClient(newt.newtId, { await sendToClient(
newt.newtId,
{
type: "newt/sync", type: "newt/sync",
data: { data: {
proxyTargets: { proxyTargets: {
@@ -35,7 +38,11 @@ export async function sendNewtSyncMessage(newt: Newt, site: Site) {
peers: peers, peers: peers,
clientTargets: targets clientTargets: targets
} }
}).catch((error) => { },
{
compress: canCompress(newt.version, "newt")
}
).catch((error) => {
logger.warn(`Error sending newt sync message:`, error); logger.warn(`Error sending newt sync message:`, error);
}); });
} }

View File

@@ -2,13 +2,14 @@ import { Target, TargetHealthCheck, db, targetHealthCheck } from "@server/db";
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function addTargets( export async function addTargets(
newtId: string, newtId: string,
targets: Target[], targets: Target[],
healthCheckData: TargetHealthCheck[], healthCheckData: TargetHealthCheck[],
protocol: string, protocol: string,
port: number | null = null version?: string | null
) { ) {
//create a list of udp and tcp targets //create a list of udp and tcp targets
const payloadTargets = targets.map((target) => { const payloadTargets = targets.map((target) => {
@@ -22,7 +23,7 @@ export async function addTargets(
data: { data: {
targets: payloadTargets targets: payloadTargets
} }
}, { incrementConfigVersion: true }); }, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
// Create a map for quick lookup // Create a map for quick lookup
const healthCheckMap = new Map<number, TargetHealthCheck>(); const healthCheckMap = new Map<number, TargetHealthCheck>();
@@ -103,14 +104,14 @@ export async function addTargets(
data: { data: {
targets: validHealthCheckTargets targets: validHealthCheckTargets
} }
}, { incrementConfigVersion: true }); }, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
} }
export async function removeTargets( export async function removeTargets(
newtId: string, newtId: string,
targets: Target[], targets: Target[],
protocol: string, protocol: string,
port: number | null = null version?: string | null
) { ) {
//create a list of udp and tcp targets //create a list of udp and tcp targets
const payloadTargets = targets.map((target) => { const payloadTargets = targets.map((target) => {
@@ -135,5 +136,5 @@ export async function removeTargets(
data: { data: {
ids: healthCheckTargets ids: healthCheckTargets
} }
}, { incrementConfigVersion: true }); }, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
} }

View File

@@ -1,5 +1,17 @@
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db"; import {
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip"; Client,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
siteResources,
sites
} from "@server/db";
import {
Alias,
generateAliasConfig,
generateRemoteSubnets
} from "@server/lib/ip";
import logger from "@server/logger"; import logger from "@server/logger";
import { and, eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers"; import { addPeer, deletePeer } from "../newt/peers";
@@ -8,9 +20,19 @@ import config from "@server/lib/config";
export async function buildSiteConfigurationForOlmClient( export async function buildSiteConfigurationForOlmClient(
client: Client, client: Client,
publicKey: string | null, publicKey: string | null,
relay: boolean relay: boolean,
jitMode: boolean = false
) { ) {
const siteConfigurations = []; const siteConfigurations: {
siteId: number;
name?: string
endpoint?: string
publicKey?: string
serverIP?: string | null
serverPort?: number | null
remoteSubnets?: string[];
aliases: Alias[];
}[] = [];
// Get all sites data // Get all sites data
const sitesData = await db const sitesData = await db
@@ -27,6 +49,40 @@ export async function buildSiteConfigurationForOlmClient(
sites: site, sites: site,
clientSitesAssociationsCache: association clientSitesAssociationsCache: association
} of sitesData) { } of sitesData) {
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (jitMode) {
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(({ siteResources }) => siteResources)
// ),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
});
continue;
}
if (!site.exitNodeId) { if (!site.exitNodeId) {
logger.warn( logger.warn(
`Site ${site.siteId} does not have exit node, skipping` `Site ${site.siteId} does not have exit node, skipping`
@@ -110,26 +166,6 @@ export async function buildSiteConfigurationForOlmClient(
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`; relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
} }
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Add site configuration to the array // Add site configuration to the array
siteConfigurations.push({ siteConfigurations.push({
siteId: site.siteId, siteId: site.siteId,

View File

@@ -17,6 +17,9 @@ import { getUserDeviceName } from "@server/db/names";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration"; import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { OlmErrorCodes, sendOlmError } from "./error"; import { OlmErrorCodes, sendOlmError } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils"; import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks";
export const handleOlmRegisterMessage: MessageHandler = async (context) => { export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!"); logger.info("Handling register olm message!");
@@ -207,6 +210,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
} }
} }
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
let jitMode = true;
if (sitesCount > 250 && build == "saas") {
// THIS IS THE MAX ON THE BUSINESS TIER
// we have too many sites
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount);
jitMode = true;
}
logger.debug( logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}` `Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
); );
@@ -233,28 +262,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db await db
.update(clientSitesAssociationsCache) .update(clientSitesAssociationsCache)
.set({ .set({
isRelayed: relay == true isRelayed: relay == true,
isJitMode: jitMode
}) })
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
} }
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
// this prevents us from accepting a register from an olm that has not hole punched yet. // this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking // the olm will pump the register so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ??? // TODO: I still think there is a better way to do this rather than locking it out here but ???
@@ -269,15 +282,10 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
const siteConfigurations = await buildSiteConfigurationForOlmClient( const siteConfigurations = await buildSiteConfigurationForOlmClient(
client, client,
publicKey, publicKey,
relay relay,
jitMode
); );
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
// if (siteConfigurations.length === 0) {
// logger.warn("No valid site configurations found");
// return;
// }
// Return connect message with all site configurations // Return connect message with all site configurations
return { return {
message: { message: {
@@ -288,6 +296,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
utilitySubnet: org.utilitySubnet utilitySubnet: org.utilitySubnet
} }
}, },
options: {
compress: canCompress(olm.version, "olm")
},
broadcast: false, broadcast: false,
excludeSender: false excludeSender: false
}; };

View File

@@ -18,7 +18,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
} }
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? logger.warn("Olm has no client!");
return; return;
} }
@@ -41,7 +41,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
return; return;
} }
const { siteId } = message.data; const { siteId, chainId } = message.data;
// Get the site // Get the site
const [site] = await db const [site] = await db
@@ -90,7 +90,8 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
data: { data: {
siteId: siteId, siteId: siteId,
relayEndpoint: exitNode.endpoint, relayEndpoint: exitNode.endpoint,
relayPort: config.getRawConfig().gerbil.clients_start_port relayPort: config.getRawConfig().gerbil.clients_start_port,
chainId
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -0,0 +1,241 @@
import {
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
Site,
siteResources
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers";
/**
 * Handles "olm/wg/peer/init" messages from an olm client.
 *
 * The olm asks the server to (re)start the add-peer handshake for a site,
 * identified either directly by `siteId` or indirectly by `resourceId`
 * (a site-resource niceId or alias within the client's org). Access is
 * validated against the client/site and client/resource association caches;
 * on any unrecoverable failure the olm is sent an
 * "olm/wg/peer/chain/cancel" so it stops retrying this chain.
 */
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
    context
) => {
    logger.info("Handling olm server init add peer handshake message!");
    const { message, client: c, sendToClient } = context;
    const olm = c as Olm;

    if (!olm) {
        logger.warn("Olm not found");
        return;
    }

    if (!olm.clientId) {
        logger.warn("Olm has no client!");
        return;
    }

    const clientId = olm.clientId;

    const [client] = await db
        .select()
        .from(clients)
        .where(eq(clients.clientId, clientId))
        .limit(1);

    if (!client) {
        logger.warn("Client not found");
        return;
    }

    const { siteId, resourceId, chainId } = message.data;

    // Tell the olm to abandon this chain so it stops retrying a request
    // that can never succeed (missing resource, no access, no exit node).
    const cancelChain = async () => {
        await sendToClient(
            olm.olmId,
            {
                type: "olm/wg/peer/chain/cancel",
                data: {
                    chainId
                }
            },
            { incrementConfigVersion: false }
        ).catch((error) => {
            logger.warn(`Error sending message:`, error);
        });
    };

    let site: Site | null = null;
    if (siteId) {
        // Direct lookup by site id.
        const [siteRes] = await db
            .select()
            .from(sites)
            .where(eq(sites.siteId, siteId))
            .limit(1);

        if (siteRes) {
            site = siteRes;
        }
    }

    if (resourceId && !site) {
        // Resolve the site indirectly through a site resource. resourceId
        // may match either the resource's niceId or its alias, scoped to
        // the client's org.
        const resources = await db
            .select()
            .from(siteResources)
            .where(
                and(
                    or(
                        eq(siteResources.niceId, resourceId),
                        eq(siteResources.alias, resourceId)
                    ),
                    eq(siteResources.orgId, client.orgId)
                )
            );

        if (!resources || resources.length === 0) {
            logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
            await cancelChain();
            return;
        }

        if (resources.length > 1) {
            // Should not happen: a niceId cannot contain a dot, an alias
            // must contain one, and both are unique within the org — so at
            // most one row can match.
            logger.error(
                `handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
            );
            return;
        }

        const resource = resources[0];

        // The client must have a cached association with this resource.
        const currentResourceAssociationCaches = await db
            .select()
            .from(clientSiteResourcesAssociationsCache)
            .where(
                and(
                    eq(
                        clientSiteResourcesAssociationsCache.siteResourceId,
                        resource.siteResourceId
                    ),
                    eq(
                        clientSiteResourcesAssociationsCache.clientId,
                        client.clientId
                    )
                )
            );

        if (currentResourceAssociationCaches.length === 0) {
            logger.error(
                `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
            );
            await cancelChain();
            return;
        }

        const siteIdFromResource = resource.siteId;

        // Resolve the site that owns the resource.
        const [siteRes] = await db
            .select()
            .from(sites)
            .where(eq(sites.siteId, siteIdFromResource))
            .limit(1);

        if (!siteRes) {
            // Fixed: previously interpolated `site` (null here) instead of
            // the id that was actually looked up.
            logger.error(
                `handleOlmServerPeerAddMessage: Site with ID ${siteIdFromResource} not found`
            );
            return;
        }
        site = siteRes;
    }

    if (!site) {
        logger.error(`handleOlmServerPeerAddMessage: Site not found`);
        return;
    }

    // The client must have a cached association with the resolved site.
    const currentSiteAssociationCaches = await db
        .select()
        .from(clientSitesAssociationsCache)
        .where(
            and(
                eq(clientSitesAssociationsCache.clientId, client.clientId),
                eq(clientSitesAssociationsCache.siteId, site.siteId)
            )
        );

    if (currentSiteAssociationCaches.length === 0) {
        logger.error(
            `handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
        );
        await cancelChain();
        return;
    }

    if (!site.exitNodeId) {
        logger.error(
            `handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
        );
        await cancelChain();
        return;
    }

    // Look up the exit node assigned to the site.
    const [exitNode] = await db
        .select()
        .from(exitNodes)
        .where(eq(exitNodes.exitNodeId, site.exitNodeId));

    if (!exitNode) {
        // Fixed: the previous message claimed the site "has no exit node",
        // but exitNodeId is set — the referenced exit-node row is missing.
        logger.error(
            `handleOlmServerPeerAddMessage: Exit node ${site.exitNodeId} for site ${site.siteId} not found`
        );
        return;
    }

    // Kick off the add-peer handshake for the client. If the peer was
    // already added to the olm this is a no-op; otherwise it triggers the
    // hole punch / relay setup.
    await initPeerAddHandshake(
        client.clientId,
        {
            siteId: site.siteId,
            exitNode: {
                publicKey: exitNode.publicKey,
                endpoint: exitNode.endpoint
            }
        },
        olm.olmId,
        chainId
    );

    return;
};

View File

@@ -54,7 +54,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
return; return;
} }
const { siteId } = message.data; const { siteId, chainId } = message.data;
// get the site // get the site
const [site] = await db const [site] = await db
@@ -179,7 +179,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
), ),
aliases: generateAliasConfig( aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources) allSiteResources.map(({ siteResources }) => siteResources)
) ),
chainId: chainId,
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -17,7 +17,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
} }
if (!olm.clientId) { if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here? logger.warn("Olm has no client!");
return; return;
} }
@@ -40,7 +40,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
return; return;
} }
const { siteId } = message.data; const { siteId, chainId } = message.data;
// Get the site // Get the site
const [site] = await db const [site] = await db
@@ -87,7 +87,8 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
type: "olm/wg/peer/unrelay", type: "olm/wg/peer/unrelay",
data: { data: {
siteId: siteId, siteId: siteId,
endpoint: site.endpoint endpoint: site.endpoint,
chainId
} }
}, },
broadcast: false, broadcast: false,

View File

@@ -11,3 +11,4 @@ export * from "./handleOlmServerPeerAddMessage";
export * from "./handleOlmUnRelayMessage"; export * from "./handleOlmUnRelayMessage";
export * from "./recoverOlmWithFingerprint"; export * from "./recoverOlmWithFingerprint";
export * from "./handleOlmDisconnectingMessage"; export * from "./handleOlmDisconnectingMessage";
export * from "./handleOlmServerInitAddPeerHandshake";

View File

@@ -1,8 +1,9 @@
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db"; import { clientSitesAssociationsCache, db, olms } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config"; import config from "@server/lib/config";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq } from "drizzle-orm"; import { and, eq } from "drizzle-orm";
import { Alias } from "yaml"; import { Alias } from "yaml";
export async function addPeer( export async function addPeer(
@@ -18,7 +19,8 @@ export async function addPeer(
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
aliases: Alias[]; aliases: Alias[];
}, },
olmId?: string olmId?: string,
version?: string | null
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -30,6 +32,7 @@ export async function addPeer(
return; // ignore this because an olm might not be associated with the client anymore return; // ignore this because an olm might not be associated with the client anymore
} }
olmId = olm.olmId; olmId = olm.olmId;
version = olm.version;
} }
await sendToClient( await sendToClient(
@@ -48,7 +51,7 @@ export async function addPeer(
aliases: peer.aliases aliases: peer.aliases
} }
}, },
{ incrementConfigVersion: true } { incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => { ).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -60,7 +63,8 @@ export async function deletePeer(
clientId: number, clientId: number,
siteId: number, siteId: number,
publicKey: string, publicKey: string,
olmId?: string olmId?: string,
version?: string | null
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -72,6 +76,7 @@ export async function deletePeer(
return; return;
} }
olmId = olm.olmId; olmId = olm.olmId;
version = olm.version;
} }
await sendToClient( await sendToClient(
@@ -83,7 +88,7 @@ export async function deletePeer(
siteId: siteId siteId: siteId
} }
}, },
{ incrementConfigVersion: true } { incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => { ).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -103,7 +108,8 @@ export async function updatePeer(
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
aliases?: Alias[] | null; aliases?: Alias[] | null;
}, },
olmId?: string olmId?: string,
version?: string | null
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -115,6 +121,7 @@ export async function updatePeer(
return; return;
} }
olmId = olm.olmId; olmId = olm.olmId;
version = olm.version;
} }
await sendToClient( await sendToClient(
@@ -132,7 +139,7 @@ export async function updatePeer(
aliases: peer.aliases aliases: peer.aliases
} }
}, },
{ incrementConfigVersion: true } { incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => { ).catch((error) => {
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
@@ -149,7 +156,8 @@ export async function initPeerAddHandshake(
endpoint: string; endpoint: string;
}; };
}, },
olmId?: string olmId?: string,
chainId?: string
) { ) {
if (!olmId) { if (!olmId) {
const [olm] = await db const [olm] = await db
@@ -173,7 +181,8 @@ export async function initPeerAddHandshake(
publicKey: peer.exitNode.publicKey, publicKey: peer.exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port, relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint endpoint: peer.exitNode.endpoint
} },
chainId
} }
}, },
{ incrementConfigVersion: true } { incrementConfigVersion: true }
@@ -181,6 +190,17 @@ export async function initPeerAddHandshake(
logger.warn(`Error sending message:`, error); logger.warn(`Error sending message:`, error);
}); });
// update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.siteId, peer.siteId)
)
);
logger.info( logger.info(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}` `Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
); );

View File

@@ -1,9 +1,17 @@
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db"; import {
Client,
db,
exitNodes,
Olm,
sites,
clientSitesAssociationsCache
} from "@server/db";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration"; import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws"; import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger"; import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm"; import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config"; import config from "@server/lib/config";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function sendOlmSyncMessage(olm: Olm, client: Client) { export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT // NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
@@ -17,10 +25,7 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
const clientSites = await db const clientSites = await db
.select() .select()
.from(clientSitesAssociationsCache) .from(clientSitesAssociationsCache)
.innerJoin( .innerJoin(sites, eq(sites.siteId, clientSitesAssociationsCache.siteId))
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId)); .where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract unique exit node IDs // Extract unique exit node IDs
@@ -68,13 +73,20 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
logger.debug("sendOlmSyncMessage: sending sync message"); logger.debug("sendOlmSyncMessage: sending sync message");
await sendToClient(olm.olmId, { await sendToClient(
olm.olmId,
{
type: "olm/sync", type: "olm/sync",
data: { data: {
sites: siteConfigurations, sites: siteConfigurations,
exitNodes: exitNodesData exitNodes: exitNodesData
} }
}).catch((error) => { },
{
compress: canCompress(olm.version, "olm")
}
).catch((error) => {
logger.warn(`Error sending olm sync message:`, error); logger.warn(`Error sending olm sync message:`, error);
}); });
} }

View File

@@ -620,7 +620,7 @@ export async function handleMessagingForUpdatedSiteResource(
await updateTargets(newt.newtId, { await updateTargets(newt.newtId, {
oldTargets: oldTargets, oldTargets: oldTargets,
newTargets: newTargets newTargets: newTargets
}); }, newt.version);
} }
const olmJobs: Promise<void>[] = []; const olmJobs: Promise<void>[] = [];

View File

@@ -264,7 +264,7 @@ export async function createTarget(
newTarget, newTarget,
healthCheck, healthCheck,
resource.protocol, resource.protocol,
resource.proxyPort newt.version
); );
} }
} }

View File

@@ -262,7 +262,7 @@ export async function updateTarget(
[updatedTarget], [updatedTarget],
[updatedHc], [updatedHc],
resource.protocol, resource.protocol,
resource.proxyPort newt.version
); );
} }
} }

View File

@@ -6,7 +6,8 @@ import {
handleDockerContainersMessage, handleDockerContainersMessage,
handleNewtPingRequestMessage, handleNewtPingRequestMessage,
handleApplyBlueprintMessage, handleApplyBlueprintMessage,
handleNewtPingMessage handleNewtPingMessage,
startNewtOfflineChecker
} from "../newt"; } from "../newt";
import { import {
handleOlmRegisterMessage, handleOlmRegisterMessage,
@@ -15,7 +16,8 @@ import {
startOlmOfflineChecker, startOlmOfflineChecker,
handleOlmServerPeerAddMessage, handleOlmServerPeerAddMessage,
handleOlmUnRelayMessage, handleOlmUnRelayMessage,
handleOlmDisconnecingMessage handleOlmDisconnecingMessage,
handleOlmServerInitAddPeerHandshake
} from "../olm"; } from "../olm";
import { handleHealthcheckStatusMessage } from "../target"; import { handleHealthcheckStatusMessage } from "../target";
import { handleRoundTripMessage } from "./handleRoundTripMessage"; import { handleRoundTripMessage } from "./handleRoundTripMessage";
@@ -23,6 +25,7 @@ import { MessageHandler } from "./types";
export const messageHandlers: Record<string, MessageHandler> = { export const messageHandlers: Record<string, MessageHandler> = {
"olm/wg/server/peer/add": handleOlmServerPeerAddMessage, "olm/wg/server/peer/add": handleOlmServerPeerAddMessage,
"olm/wg/server/peer/init": handleOlmServerInitAddPeerHandshake,
"olm/wg/register": handleOlmRegisterMessage, "olm/wg/register": handleOlmRegisterMessage,
"olm/wg/relay": handleOlmRelayMessage, "olm/wg/relay": handleOlmRelayMessage,
"olm/wg/unrelay": handleOlmUnRelayMessage, "olm/wg/unrelay": handleOlmUnRelayMessage,
@@ -41,3 +44,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
}; };
startOlmOfflineChecker(); // this is to handle the offline check for olms startOlmOfflineChecker(); // this is to handle the offline check for olms
startNewtOfflineChecker(); // this is to handle the offline check for newts

View File

@@ -24,7 +24,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
clientType?: ClientType; clientType?: ClientType;
connectionId?: string; connectionId?: string;
isFullyConnected?: boolean; isFullyConnected?: boolean;
pendingMessages?: Buffer[]; pendingMessages?: { data: Buffer; isBinary: boolean }[];
configVersion?: number; configVersion?: number;
} }
@@ -73,6 +73,7 @@ export type MessageHandler = (
// Options for sending messages with config version tracking // Options for sending messages with config version tracking
export interface SendMessageOptions { export interface SendMessageOptions {
incrementConfigVersion?: boolean; incrementConfigVersion?: boolean;
compress?: boolean;
} }
// Redis message type for cross-node communication // Redis message type for cross-node communication

View File

@@ -1,8 +1,9 @@
import { Router, Request, Response } from "express"; import { Router, Request, Response } from "express";
import zlib from "zlib";
import { Server as HttpServer } from "http"; import { Server as HttpServer } from "http";
import { WebSocket, WebSocketServer } from "ws"; import { WebSocket, WebSocketServer } from "ws";
import { Socket } from "net"; import { Socket } from "net";
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db"; import { Newt, newts, NewtSession, olms, Olm, OlmSession, sites } from "@server/db";
import { eq } from "drizzle-orm"; import { eq } from "drizzle-orm";
import { db } from "@server/db"; import { db } from "@server/db";
import { validateNewtSessionToken } from "@server/auth/sessions/newt"; import { validateNewtSessionToken } from "@server/auth/sessions/newt";
@@ -116,11 +117,20 @@ const sendToClientLocal = async (
}; };
const messageString = JSON.stringify(messageWithVersion); const messageString = JSON.stringify(messageWithVersion);
if (options.compress) {
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(messageString); client.send(messageString);
} }
}); });
}
return true; return true;
}; };
@@ -147,12 +157,23 @@ const broadcastToAllExceptLocal = async (
...message, ...message,
configVersion configVersion
}; };
if (options.compress) {
const compressed = zlib.gzipSync(
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
);
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {
client.send(compressed);
}
});
} else {
clients.forEach((client) => { clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) { if (client.readyState === WebSocket.OPEN) {
client.send(JSON.stringify(messageWithVersion)); client.send(JSON.stringify(messageWithVersion));
} }
}); });
} }
}
}); });
}; };
@@ -286,9 +307,12 @@ const setupConnection = async (
clientType === "newt" ? (client as Newt).newtId : (client as Olm).olmId; clientType === "newt" ? (client as Newt).newtId : (client as Olm).olmId;
await addClient(clientType, clientId, ws); await addClient(clientType, clientId, ws);
ws.on("message", async (data) => { ws.on("message", async (data, isBinary) => {
try { try {
const message: WSMessage = JSON.parse(data.toString()); const messageBuffer = isBinary
? zlib.gunzipSync(data as Buffer)
: (data as Buffer);
const message: WSMessage = JSON.parse(messageBuffer.toString());
if (!message.type || typeof message.type !== "string") { if (!message.type || typeof message.type !== "string") {
throw new Error( throw new Error(
@@ -356,6 +380,31 @@ const setupConnection = async (
); );
}); });
// Handle WebSocket protocol-level pings from older newt clients that do
// not send application-level "newt/ping" messages. Update the site's
// online state and lastPing timestamp so the offline checker treats them
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
if (!newtClient.siteId) return;
try {
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
});
}
ws.on("error", (error: Error) => { ws.on("error", (error: Error) => {
logger.error( logger.error(
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`, `WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,