mirror of
https://github.com/fosrl/pangolin.git
synced 2026-03-14 22:56:37 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a43f1ef4b | ||
|
|
75ab074805 | ||
|
|
dc4e0253de | ||
|
|
cccf236042 | ||
|
|
63fd63c65c | ||
|
|
beee1d692d | ||
|
|
fde786ca84 | ||
|
|
3086fdd064 | ||
|
|
cf5fb8dc33 | ||
|
|
0503c6e66e | ||
|
|
9405b0b70a | ||
|
|
2a5c9465e9 | ||
|
|
f36b66e397 | ||
|
|
1bfff630bf | ||
|
|
c73a39f797 |
@@ -1,6 +1,10 @@
|
|||||||
|
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
|
||||||
|
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
|
||||||
import { cleanup as wsCleanup } from "#dynamic/routers/ws";
|
import { cleanup as wsCleanup } from "#dynamic/routers/ws";
|
||||||
|
|
||||||
async function cleanup() {
|
async function cleanup() {
|
||||||
|
await flushBandwidthToDb();
|
||||||
|
await flushSiteBandwidthToDb();
|
||||||
await wsCleanup();
|
await wsCleanup();
|
||||||
|
|
||||||
process.exit(0);
|
process.exit(0);
|
||||||
|
|||||||
@@ -89,6 +89,7 @@ export const sites = pgTable("sites", {
|
|||||||
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
|
lastBandwidthUpdate: varchar("lastBandwidthUpdate"),
|
||||||
type: varchar("type").notNull(), // "newt" or "wireguard"
|
type: varchar("type").notNull(), // "newt" or "wireguard"
|
||||||
online: boolean("online").notNull().default(false),
|
online: boolean("online").notNull().default(false),
|
||||||
|
lastPing: integer("lastPing"),
|
||||||
address: varchar("address"),
|
address: varchar("address"),
|
||||||
endpoint: varchar("endpoint"),
|
endpoint: varchar("endpoint"),
|
||||||
publicKey: varchar("publicKey"),
|
publicKey: varchar("publicKey"),
|
||||||
@@ -721,6 +722,7 @@ export const clientSitesAssociationsCache = pgTable(
|
|||||||
.notNull(),
|
.notNull(),
|
||||||
siteId: integer("siteId").notNull(),
|
siteId: integer("siteId").notNull(),
|
||||||
isRelayed: boolean("isRelayed").notNull().default(false),
|
isRelayed: boolean("isRelayed").notNull().default(false),
|
||||||
|
isJitMode: boolean("isJitMode").notNull().default(false),
|
||||||
endpoint: varchar("endpoint"),
|
endpoint: varchar("endpoint"),
|
||||||
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
|
publicKey: varchar("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ export const sites = sqliteTable("sites", {
|
|||||||
lastBandwidthUpdate: text("lastBandwidthUpdate"),
|
lastBandwidthUpdate: text("lastBandwidthUpdate"),
|
||||||
type: text("type").notNull(), // "newt" or "wireguard"
|
type: text("type").notNull(), // "newt" or "wireguard"
|
||||||
online: integer("online", { mode: "boolean" }).notNull().default(false),
|
online: integer("online", { mode: "boolean" }).notNull().default(false),
|
||||||
|
lastPing: integer("lastPing"),
|
||||||
|
|
||||||
// exit node stuff that is how to connect to the site when it has a wg server
|
// exit node stuff that is how to connect to the site when it has a wg server
|
||||||
address: text("address"), // this is the address of the wireguard interface in newt
|
address: text("address"), // this is the address of the wireguard interface in newt
|
||||||
@@ -410,6 +411,9 @@ export const clientSitesAssociationsCache = sqliteTable(
|
|||||||
isRelayed: integer("isRelayed", { mode: "boolean" })
|
isRelayed: integer("isRelayed", { mode: "boolean" })
|
||||||
.notNull()
|
.notNull()
|
||||||
.default(false),
|
.default(false),
|
||||||
|
isJitMode: integer("isJitMode", { mode: "boolean" })
|
||||||
|
.notNull()
|
||||||
|
.default(false),
|
||||||
endpoint: text("endpoint"),
|
endpoint: text("endpoint"),
|
||||||
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
|
publicKey: text("publicKey") // this will act as the session's public key for hole punching so we can track when it changes
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ export async function applyBlueprint({
|
|||||||
[target],
|
[target],
|
||||||
matchingHealthcheck ? [matchingHealthcheck] : [],
|
matchingHealthcheck ? [matchingHealthcheck] : [],
|
||||||
result.proxyResource.protocol,
|
result.proxyResource.protocol,
|
||||||
result.proxyResource.proxyPort
|
site.newt.version
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
20
server/lib/clientVersionChecks.ts
Normal file
20
server/lib/clientVersionChecks.ts
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
import semver from "semver";
|
||||||
|
|
||||||
|
export function canCompress(
|
||||||
|
clientVersion: string | null | undefined,
|
||||||
|
type: "newt" | "olm"
|
||||||
|
): boolean {
|
||||||
|
try {
|
||||||
|
if (!clientVersion) return false;
|
||||||
|
// check if it is a valid semver
|
||||||
|
if (!semver.valid(clientVersion)) return false;
|
||||||
|
if (type === "newt") {
|
||||||
|
return semver.gte(clientVersion, "1.10.3");
|
||||||
|
} else if (type === "olm") {
|
||||||
|
return semver.gte(clientVersion, "1.4.3");
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
} catch {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -477,6 +477,7 @@ async function handleMessagesForSiteClients(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (isAdd) {
|
if (isAdd) {
|
||||||
|
// TODO: if we are in jit mode here should we really be sending this?
|
||||||
await initPeerAddHandshake(
|
await initPeerAddHandshake(
|
||||||
// this will kick off the add peer process for the client
|
// this will kick off the add peer process for the client
|
||||||
client.clientId,
|
client.clientId,
|
||||||
@@ -669,7 +670,11 @@ async function handleSubnetProxyTargetUpdates(
|
|||||||
`Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
|
`Adding ${targetsToAdd.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
|
||||||
);
|
);
|
||||||
proxyJobs.push(
|
proxyJobs.push(
|
||||||
addSubnetProxyTargets(newt.newtId, targetsToAdd)
|
addSubnetProxyTargets(
|
||||||
|
newt.newtId,
|
||||||
|
targetsToAdd,
|
||||||
|
newt.version
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -705,7 +710,11 @@ async function handleSubnetProxyTargetUpdates(
|
|||||||
`Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
|
`Removing ${targetsToRemove.length} subnet proxy targets for siteResource ${siteResource.siteResourceId}`
|
||||||
);
|
);
|
||||||
proxyJobs.push(
|
proxyJobs.push(
|
||||||
removeSubnetProxyTargets(newt.newtId, targetsToRemove)
|
removeSubnetProxyTargets(
|
||||||
|
newt.newtId,
|
||||||
|
targetsToRemove,
|
||||||
|
newt.version
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1080,6 +1089,7 @@ async function handleMessagesForClientSites(
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: if we are in jit mode here should we really be sending this?
|
||||||
await initPeerAddHandshake(
|
await initPeerAddHandshake(
|
||||||
// this will kick off the add peer process for the client
|
// this will kick off the add peer process for the client
|
||||||
client.clientId,
|
client.clientId,
|
||||||
@@ -1146,7 +1156,7 @@ async function handleMessagesForClientResources(
|
|||||||
// Add subnet proxy targets for each site
|
// Add subnet proxy targets for each site
|
||||||
for (const [siteId, resources] of addedBySite.entries()) {
|
for (const [siteId, resources] of addedBySite.entries()) {
|
||||||
const [newt] = await trx
|
const [newt] = await trx
|
||||||
.select({ newtId: newts.newtId })
|
.select({ newtId: newts.newtId, version: newts.version })
|
||||||
.from(newts)
|
.from(newts)
|
||||||
.where(eq(newts.siteId, siteId))
|
.where(eq(newts.siteId, siteId))
|
||||||
.limit(1);
|
.limit(1);
|
||||||
@@ -1168,7 +1178,13 @@ async function handleMessagesForClientResources(
|
|||||||
]);
|
]);
|
||||||
|
|
||||||
if (targets.length > 0) {
|
if (targets.length > 0) {
|
||||||
proxyJobs.push(addSubnetProxyTargets(newt.newtId, targets));
|
proxyJobs.push(
|
||||||
|
addSubnetProxyTargets(
|
||||||
|
newt.newtId,
|
||||||
|
targets,
|
||||||
|
newt.version
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -1217,7 +1233,7 @@ async function handleMessagesForClientResources(
|
|||||||
// Remove subnet proxy targets for each site
|
// Remove subnet proxy targets for each site
|
||||||
for (const [siteId, resources] of removedBySite.entries()) {
|
for (const [siteId, resources] of removedBySite.entries()) {
|
||||||
const [newt] = await trx
|
const [newt] = await trx
|
||||||
.select({ newtId: newts.newtId })
|
.select({ newtId: newts.newtId, version: newts.version })
|
||||||
.from(newts)
|
.from(newts)
|
||||||
.where(eq(newts.siteId, siteId))
|
.where(eq(newts.siteId, siteId))
|
||||||
.limit(1);
|
.limit(1);
|
||||||
@@ -1240,7 +1256,11 @@ async function handleMessagesForClientResources(
|
|||||||
|
|
||||||
if (targets.length > 0) {
|
if (targets.length > 0) {
|
||||||
proxyJobs.push(
|
proxyJobs.push(
|
||||||
removeSubnetProxyTargets(newt.newtId, targets)
|
removeSubnetProxyTargets(
|
||||||
|
newt.newtId,
|
||||||
|
targets,
|
||||||
|
newt.version
|
||||||
|
)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,12 @@
|
|||||||
|
|
||||||
import { rateLimitService } from "#private/lib/rateLimit";
|
import { rateLimitService } from "#private/lib/rateLimit";
|
||||||
import { cleanup as wsCleanup } from "#private/routers/ws";
|
import { cleanup as wsCleanup } from "#private/routers/ws";
|
||||||
|
import { flushBandwidthToDb } from "@server/routers/newt/handleReceiveBandwidthMessage";
|
||||||
|
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth";
|
||||||
|
|
||||||
async function cleanup() {
|
async function cleanup() {
|
||||||
|
await flushBandwidthToDb();
|
||||||
|
await flushSiteBandwidthToDb();
|
||||||
await rateLimitService.cleanup();
|
await rateLimitService.cleanup();
|
||||||
await wsCleanup();
|
await wsCleanup();
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ import HttpCode from "@server/types/HttpCode";
|
|||||||
import createHttpError from "http-errors";
|
import createHttpError from "http-errors";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { fromError } from "zod-validation-error";
|
import { fromError } from "zod-validation-error";
|
||||||
import { OpenAPITags, registry } from "@server/openApi";
|
|
||||||
import { eq, or, and } from "drizzle-orm";
|
import { eq, or, and } from "drizzle-orm";
|
||||||
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
|
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
|
||||||
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
|
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
|
||||||
@@ -64,6 +63,7 @@ export type SignSshKeyResponse = {
|
|||||||
sshUsername: string;
|
sshUsername: string;
|
||||||
sshHost: string;
|
sshHost: string;
|
||||||
resourceId: number;
|
resourceId: number;
|
||||||
|
siteId: number;
|
||||||
keyId: string;
|
keyId: string;
|
||||||
validPrincipals: string[];
|
validPrincipals: string[];
|
||||||
validAfter: string;
|
validAfter: string;
|
||||||
@@ -453,6 +453,7 @@ export async function signSshKey(
|
|||||||
sshUsername: usernameToUse,
|
sshUsername: usernameToUse,
|
||||||
sshHost: sshHost,
|
sshHost: sshHost,
|
||||||
resourceId: resource.siteResourceId,
|
resourceId: resource.siteResourceId,
|
||||||
|
siteId: resource.siteId,
|
||||||
keyId: cert.keyId,
|
keyId: cert.keyId,
|
||||||
validPrincipals: cert.validPrincipals,
|
validPrincipals: cert.validPrincipals,
|
||||||
validAfter: cert.validAfter.toISOString(),
|
validAfter: cert.validAfter.toISOString(),
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { Router, Request, Response } from "express";
|
import { Router, Request, Response } from "express";
|
||||||
|
import zlib from "zlib";
|
||||||
import { Server as HttpServer } from "http";
|
import { Server as HttpServer } from "http";
|
||||||
import { WebSocket, WebSocketServer } from "ws";
|
import { WebSocket, WebSocketServer } from "ws";
|
||||||
import { Socket } from "net";
|
import { Socket } from "net";
|
||||||
@@ -24,7 +25,8 @@ import {
|
|||||||
OlmSession,
|
OlmSession,
|
||||||
RemoteExitNode,
|
RemoteExitNode,
|
||||||
RemoteExitNodeSession,
|
RemoteExitNodeSession,
|
||||||
remoteExitNodes
|
remoteExitNodes,
|
||||||
|
sites
|
||||||
} from "@server/db";
|
} from "@server/db";
|
||||||
import { eq } from "drizzle-orm";
|
import { eq } from "drizzle-orm";
|
||||||
import { db } from "@server/db";
|
import { db } from "@server/db";
|
||||||
@@ -57,11 +59,13 @@ const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection
|
|||||||
const processMessage = async (
|
const processMessage = async (
|
||||||
ws: AuthenticatedWebSocket,
|
ws: AuthenticatedWebSocket,
|
||||||
data: Buffer,
|
data: Buffer,
|
||||||
|
isBinary: boolean,
|
||||||
clientId: string,
|
clientId: string,
|
||||||
clientType: ClientType
|
clientType: ClientType
|
||||||
): Promise<void> => {
|
): Promise<void> => {
|
||||||
try {
|
try {
|
||||||
const message: WSMessage = JSON.parse(data.toString());
|
const messageBuffer = isBinary ? zlib.gunzipSync(data) : data;
|
||||||
|
const message: WSMessage = JSON.parse(messageBuffer.toString());
|
||||||
|
|
||||||
// logger.debug(
|
// logger.debug(
|
||||||
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||||
@@ -76,7 +80,7 @@ const processMessage = async (
|
|||||||
clientId,
|
clientId,
|
||||||
message.type, // Pass message type for granular limiting
|
message.type, // Pass message type for granular limiting
|
||||||
100, // max requests per window
|
100, // max requests per window
|
||||||
20, // max requests per message type per window
|
100, // max requests per message type per window
|
||||||
60 * 1000 // window in milliseconds
|
60 * 1000 // window in milliseconds
|
||||||
);
|
);
|
||||||
if (rateLimitResult.isLimited) {
|
if (rateLimitResult.isLimited) {
|
||||||
@@ -163,8 +167,16 @@ const processPendingMessages = async (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const jobs = [];
|
const jobs = [];
|
||||||
for (const messageData of ws.pendingMessages) {
|
for (const pending of ws.pendingMessages) {
|
||||||
jobs.push(processMessage(ws, messageData, clientId, clientType));
|
jobs.push(
|
||||||
|
processMessage(
|
||||||
|
ws,
|
||||||
|
pending.data,
|
||||||
|
pending.isBinary,
|
||||||
|
clientId,
|
||||||
|
clientType
|
||||||
|
)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
await Promise.all(jobs);
|
await Promise.all(jobs);
|
||||||
@@ -325,7 +337,9 @@ const addClient = async (
|
|||||||
// Check Redis first if enabled
|
// Check Redis first if enabled
|
||||||
if (redisManager.isRedisEnabled()) {
|
if (redisManager.isRedisEnabled()) {
|
||||||
try {
|
try {
|
||||||
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
|
const redisVersion = await redisManager.get(
|
||||||
|
getConfigVersionKey(clientId)
|
||||||
|
);
|
||||||
if (redisVersion !== null) {
|
if (redisVersion !== null) {
|
||||||
configVersion = parseInt(redisVersion, 10);
|
configVersion = parseInt(redisVersion, 10);
|
||||||
// Sync to local cache
|
// Sync to local cache
|
||||||
@@ -337,7 +351,10 @@ const addClient = async (
|
|||||||
} else {
|
} else {
|
||||||
// Use local cache version and sync to Redis
|
// Use local cache version and sync to Redis
|
||||||
configVersion = clientConfigVersions.get(clientId) || 0;
|
configVersion = clientConfigVersions.get(clientId) || 0;
|
||||||
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
|
await redisManager.set(
|
||||||
|
getConfigVersionKey(clientId),
|
||||||
|
configVersion.toString()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error("Failed to get/set config version in Redis:", error);
|
logger.error("Failed to get/set config version in Redis:", error);
|
||||||
@@ -432,7 +449,9 @@ const removeClient = async (
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Helper to get the current config version for a client
|
// Helper to get the current config version for a client
|
||||||
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
|
const getClientConfigVersion = async (
|
||||||
|
clientId: string
|
||||||
|
): Promise<number | undefined> => {
|
||||||
// Try Redis first if available
|
// Try Redis first if available
|
||||||
if (redisManager.isRedisEnabled()) {
|
if (redisManager.isRedisEnabled()) {
|
||||||
try {
|
try {
|
||||||
@@ -502,11 +521,26 @@ const sendToClientLocal = async (
|
|||||||
};
|
};
|
||||||
|
|
||||||
const messageString = JSON.stringify(messageWithVersion);
|
const messageString = JSON.stringify(messageWithVersion);
|
||||||
|
if (options.compress) {
|
||||||
|
logger.debug(
|
||||||
|
`Message size before compression: ${messageString.length} bytes`
|
||||||
|
);
|
||||||
|
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
|
||||||
|
logger.debug(
|
||||||
|
`Message size after compression: ${compressed.length} bytes`
|
||||||
|
);
|
||||||
|
clients.forEach((client) => {
|
||||||
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
|
client.send(compressed);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
clients.forEach((client) => {
|
clients.forEach((client) => {
|
||||||
if (client.readyState === WebSocket.OPEN) {
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
client.send(messageString);
|
client.send(messageString);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
@@ -532,6 +566,16 @@ const broadcastToAllExceptLocal = async (
|
|||||||
configVersion
|
configVersion
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if (options.compress) {
|
||||||
|
const compressed = zlib.gzipSync(
|
||||||
|
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
|
||||||
|
);
|
||||||
|
clients.forEach((client) => {
|
||||||
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
|
client.send(compressed);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
clients.forEach((client) => {
|
clients.forEach((client) => {
|
||||||
if (client.readyState === WebSocket.OPEN) {
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
client.send(JSON.stringify(messageWithVersion));
|
client.send(JSON.stringify(messageWithVersion));
|
||||||
@@ -539,6 +583,7 @@ const broadcastToAllExceptLocal = async (
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Cross-node message sending (via Redis)
|
// Cross-node message sending (via Redis)
|
||||||
@@ -762,7 +807,7 @@ const setupConnection = async (
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set up message handler FIRST to prevent race condition
|
// Set up message handler FIRST to prevent race condition
|
||||||
ws.on("message", async (data) => {
|
ws.on("message", async (data, isBinary) => {
|
||||||
if (!ws.isFullyConnected) {
|
if (!ws.isFullyConnected) {
|
||||||
// Queue message for later processing with limits
|
// Queue message for later processing with limits
|
||||||
ws.pendingMessages = ws.pendingMessages || [];
|
ws.pendingMessages = ws.pendingMessages || [];
|
||||||
@@ -777,11 +822,17 @@ const setupConnection = async (
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
|
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
|
||||||
);
|
);
|
||||||
ws.pendingMessages.push(data as Buffer);
|
ws.pendingMessages.push({ data: data as Buffer, isBinary });
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
await processMessage(ws, data as Buffer, clientId, clientType);
|
await processMessage(
|
||||||
|
ws,
|
||||||
|
data as Buffer,
|
||||||
|
isBinary,
|
||||||
|
clientId,
|
||||||
|
clientType
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Set up other event handlers before async operations
|
// Set up other event handlers before async operations
|
||||||
@@ -796,6 +847,31 @@ const setupConnection = async (
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Handle WebSocket protocol-level pings from older newt clients that do
|
||||||
|
// not send application-level "newt/ping" messages. Update the site's
|
||||||
|
// online state and lastPing timestamp so the offline checker treats them
|
||||||
|
// the same as modern newt clients.
|
||||||
|
if (clientType === "newt") {
|
||||||
|
const newtClient = client as Newt;
|
||||||
|
ws.on("ping", async () => {
|
||||||
|
if (!newtClient.siteId) return;
|
||||||
|
try {
|
||||||
|
await db
|
||||||
|
.update(sites)
|
||||||
|
.set({
|
||||||
|
online: true,
|
||||||
|
lastPing: Math.floor(Date.now() / 1000)
|
||||||
|
})
|
||||||
|
.where(eq(sites.siteId, newtClient.siteId));
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
"Error updating newt site online state on WS ping",
|
||||||
|
{ error }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
ws.on("error", (error: Error) => {
|
ws.on("error", (error: Error) => {
|
||||||
logger.error(
|
logger.error(
|
||||||
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
|
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
|
||||||
|
|||||||
@@ -1,51 +1,38 @@
|
|||||||
import { sendToClient } from "#dynamic/routers/ws";
|
import { sendToClient } from "#dynamic/routers/ws";
|
||||||
import { db, olms, Transaction } from "@server/db";
|
import { db, olms, Transaction } from "@server/db";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
|
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { eq } from "drizzle-orm";
|
import { eq } from "drizzle-orm";
|
||||||
|
|
||||||
const BATCH_SIZE = 50;
|
export async function addTargets(
|
||||||
const BATCH_DELAY_MS = 50;
|
newtId: string,
|
||||||
|
targets: SubnetProxyTarget[],
|
||||||
function sleep(ms: number): Promise<void> {
|
version?: string | null
|
||||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
) {
|
||||||
}
|
await sendToClient(
|
||||||
|
newtId,
|
||||||
function chunkArray<T>(array: T[], size: number): T[][] {
|
{
|
||||||
const chunks: T[][] = [];
|
|
||||||
for (let i = 0; i < array.length; i += size) {
|
|
||||||
chunks.push(array.slice(i, i + size));
|
|
||||||
}
|
|
||||||
return chunks;
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
|
|
||||||
const batches = chunkArray(targets, BATCH_SIZE);
|
|
||||||
for (let i = 0; i < batches.length; i++) {
|
|
||||||
if (i > 0) {
|
|
||||||
await sleep(BATCH_DELAY_MS);
|
|
||||||
}
|
|
||||||
await sendToClient(newtId, {
|
|
||||||
type: `newt/wg/targets/add`,
|
type: `newt/wg/targets/add`,
|
||||||
data: batches[i]
|
data: targets
|
||||||
}, { incrementConfigVersion: true });
|
},
|
||||||
}
|
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function removeTargets(
|
export async function removeTargets(
|
||||||
newtId: string,
|
newtId: string,
|
||||||
targets: SubnetProxyTarget[]
|
targets: SubnetProxyTarget[],
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
const batches = chunkArray(targets, BATCH_SIZE);
|
await sendToClient(
|
||||||
for (let i = 0; i < batches.length; i++) {
|
newtId,
|
||||||
if (i > 0) {
|
{
|
||||||
await sleep(BATCH_DELAY_MS);
|
|
||||||
}
|
|
||||||
await sendToClient(newtId, {
|
|
||||||
type: `newt/wg/targets/remove`,
|
type: `newt/wg/targets/remove`,
|
||||||
data: batches[i]
|
data: targets
|
||||||
},{ incrementConfigVersion: true });
|
},
|
||||||
}
|
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function updateTargets(
|
export async function updateTargets(
|
||||||
@@ -53,34 +40,31 @@ export async function updateTargets(
|
|||||||
targets: {
|
targets: {
|
||||||
oldTargets: SubnetProxyTarget[];
|
oldTargets: SubnetProxyTarget[];
|
||||||
newTargets: SubnetProxyTarget[];
|
newTargets: SubnetProxyTarget[];
|
||||||
}
|
},
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
const oldBatches = chunkArray(targets.oldTargets, BATCH_SIZE);
|
await sendToClient(
|
||||||
const newBatches = chunkArray(targets.newTargets, BATCH_SIZE);
|
newtId,
|
||||||
const maxBatches = Math.max(oldBatches.length, newBatches.length);
|
{
|
||||||
|
|
||||||
for (let i = 0; i < maxBatches; i++) {
|
|
||||||
if (i > 0) {
|
|
||||||
await sleep(BATCH_DELAY_MS);
|
|
||||||
}
|
|
||||||
await sendToClient(newtId, {
|
|
||||||
type: `newt/wg/targets/update`,
|
type: `newt/wg/targets/update`,
|
||||||
data: {
|
data: {
|
||||||
oldTargets: oldBatches[i] || [],
|
oldTargets: targets.oldTargets,
|
||||||
newTargets: newBatches[i] || []
|
newTargets: targets.newTargets
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true }).catch((error) => {
|
},
|
||||||
|
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
|
||||||
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
export async function addPeerData(
|
export async function addPeerData(
|
||||||
clientId: number,
|
clientId: number,
|
||||||
siteId: number,
|
siteId: number,
|
||||||
remoteSubnets: string[],
|
remoteSubnets: string[],
|
||||||
aliases: Alias[],
|
aliases: Alias[],
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -92,16 +76,21 @@ export async function addPeerData(
|
|||||||
return; // ignore this because an olm might not be associated with the client anymore
|
return; // ignore this because an olm might not be associated with the client anymore
|
||||||
}
|
}
|
||||||
olmId = olm.olmId;
|
olmId = olm.olmId;
|
||||||
|
version = olm.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
await sendToClient(olmId, {
|
await sendToClient(
|
||||||
|
olmId,
|
||||||
|
{
|
||||||
type: `olm/wg/peer/data/add`,
|
type: `olm/wg/peer/data/add`,
|
||||||
data: {
|
data: {
|
||||||
siteId: siteId,
|
siteId: siteId,
|
||||||
remoteSubnets: remoteSubnets,
|
remoteSubnets: remoteSubnets,
|
||||||
aliases: aliases
|
aliases: aliases
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true }).catch((error) => {
|
},
|
||||||
|
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||||
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -111,7 +100,8 @@ export async function removePeerData(
|
|||||||
siteId: number,
|
siteId: number,
|
||||||
remoteSubnets: string[],
|
remoteSubnets: string[],
|
||||||
aliases: Alias[],
|
aliases: Alias[],
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -123,16 +113,21 @@ export async function removePeerData(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
olmId = olm.olmId;
|
olmId = olm.olmId;
|
||||||
|
version = olm.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
await sendToClient(olmId, {
|
await sendToClient(
|
||||||
|
olmId,
|
||||||
|
{
|
||||||
type: `olm/wg/peer/data/remove`,
|
type: `olm/wg/peer/data/remove`,
|
||||||
data: {
|
data: {
|
||||||
siteId: siteId,
|
siteId: siteId,
|
||||||
remoteSubnets: remoteSubnets,
|
remoteSubnets: remoteSubnets,
|
||||||
aliases: aliases
|
aliases: aliases
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true }).catch((error) => {
|
},
|
||||||
|
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||||
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -152,7 +147,8 @@ export async function updatePeerData(
|
|||||||
newAliases: Alias[];
|
newAliases: Alias[];
|
||||||
}
|
}
|
||||||
| undefined,
|
| undefined,
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -164,16 +160,21 @@ export async function updatePeerData(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
olmId = olm.olmId;
|
olmId = olm.olmId;
|
||||||
|
version = olm.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
await sendToClient(olmId, {
|
await sendToClient(
|
||||||
|
olmId,
|
||||||
|
{
|
||||||
type: `olm/wg/peer/data/update`,
|
type: `olm/wg/peer/data/update`,
|
||||||
data: {
|
data: {
|
||||||
siteId: siteId,
|
siteId: siteId,
|
||||||
...remoteSubnets,
|
...remoteSubnets,
|
||||||
...aliases
|
...aliases
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true }).catch((error) => {
|
},
|
||||||
|
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||||
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { Request, Response, NextFunction } from "express";
|
import { Request, Response, NextFunction } from "express";
|
||||||
import { eq, and, lt, inArray, sql } from "drizzle-orm";
|
import { eq, sql } from "drizzle-orm";
|
||||||
import { sites } from "@server/db";
|
import { sites } from "@server/db";
|
||||||
import { db } from "@server/db";
|
import { db } from "@server/db";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
@@ -11,19 +11,31 @@ import { FeatureId } from "@server/lib/billing/features";
|
|||||||
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
|
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
|
||||||
import { build } from "@server/build";
|
import { build } from "@server/build";
|
||||||
|
|
||||||
// Track sites that are already offline to avoid unnecessary queries
|
|
||||||
const offlineSites = new Set<string>();
|
|
||||||
|
|
||||||
// Retry configuration for deadlock handling
|
|
||||||
const MAX_RETRIES = 3;
|
|
||||||
const BASE_DELAY_MS = 50;
|
|
||||||
|
|
||||||
interface PeerBandwidth {
|
interface PeerBandwidth {
|
||||||
publicKey: string;
|
publicKey: string;
|
||||||
bytesIn: number;
|
bytesIn: number;
|
||||||
bytesOut: number;
|
bytesOut: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface AccumulatorEntry {
|
||||||
|
bytesIn: number;
|
||||||
|
bytesOut: number;
|
||||||
|
/** Present when the update came through a remote exit node. */
|
||||||
|
exitNodeId?: number;
|
||||||
|
/** Whether to record egress usage for billing purposes. */
|
||||||
|
calcUsage: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry configuration for deadlock handling
|
||||||
|
const MAX_RETRIES = 3;
|
||||||
|
const BASE_DELAY_MS = 50;
|
||||||
|
|
||||||
|
// How often to flush accumulated bandwidth data to the database
|
||||||
|
const FLUSH_INTERVAL_MS = 30_000; // 30 seconds
|
||||||
|
|
||||||
|
// In-memory accumulator: publicKey -> AccumulatorEntry
|
||||||
|
let accumulator = new Map<string, AccumulatorEntry>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if an error is a deadlock error
|
* Check if an error is a deadlock error
|
||||||
*/
|
*/
|
||||||
@@ -63,6 +75,220 @@ async function withDeadlockRetry<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flush all accumulated site bandwidth data to the database.
|
||||||
|
*
|
||||||
|
* Swaps out the accumulator before writing so that any bandwidth messages
|
||||||
|
* received during the flush are captured in the new accumulator rather than
|
||||||
|
* being lost or causing contention. Entries that fail to write are re-queued
|
||||||
|
* back into the accumulator so they will be retried on the next flush.
|
||||||
|
*
|
||||||
|
* This function is exported so that the application's graceful-shutdown
|
||||||
|
* cleanup handler can call it before the process exits.
|
||||||
|
*/
|
||||||
|
export async function flushSiteBandwidthToDb(): Promise<void> {
|
||||||
|
if (accumulator.size === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Atomically swap out the accumulator so new data keeps flowing in
|
||||||
|
// while we write the snapshot to the database.
|
||||||
|
const snapshot = accumulator;
|
||||||
|
accumulator = new Map<string, AccumulatorEntry>();
|
||||||
|
|
||||||
|
const currentTime = new Date().toISOString();
|
||||||
|
|
||||||
|
// Sort by publicKey for consistent lock ordering across concurrent
|
||||||
|
// writers — deadlock-prevention strategy.
|
||||||
|
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
|
||||||
|
a.localeCompare(b)
|
||||||
|
);
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
`Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
|
||||||
|
);
|
||||||
|
|
||||||
|
// Aggregate billing usage by org, collected during the DB update loop.
|
||||||
|
const orgUsageMap = new Map<string, number>();
|
||||||
|
|
||||||
|
for (const [publicKey, { bytesIn, bytesOut, exitNodeId, calcUsage }] of sortedEntries) {
|
||||||
|
try {
|
||||||
|
const updatedSite = await withDeadlockRetry(async () => {
|
||||||
|
const [result] = await db
|
||||||
|
.update(sites)
|
||||||
|
.set({
|
||||||
|
megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
|
||||||
|
megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
|
||||||
|
lastBandwidthUpdate: currentTime
|
||||||
|
})
|
||||||
|
.where(eq(sites.pubKey, publicKey))
|
||||||
|
.returning({
|
||||||
|
orgId: sites.orgId,
|
||||||
|
siteId: sites.siteId
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}, `flush bandwidth for site ${publicKey}`);
|
||||||
|
|
||||||
|
if (updatedSite) {
|
||||||
|
if (exitNodeId) {
|
||||||
|
const notAllowed = await checkExitNodeOrg(
|
||||||
|
exitNodeId,
|
||||||
|
updatedSite.orgId
|
||||||
|
);
|
||||||
|
if (notAllowed) {
|
||||||
|
logger.warn(
|
||||||
|
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
|
||||||
|
);
|
||||||
|
// Skip usage tracking for this site but continue
|
||||||
|
// processing the rest.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (calcUsage) {
|
||||||
|
const totalBandwidth = bytesIn + bytesOut;
|
||||||
|
const current = orgUsageMap.get(updatedSite.orgId) ?? 0;
|
||||||
|
orgUsageMap.set(updatedSite.orgId, current + totalBandwidth);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
`Failed to flush bandwidth for site ${publicKey}:`,
|
||||||
|
error
|
||||||
|
);
|
||||||
|
|
||||||
|
// Re-queue the failed entry so it is retried on the next flush
|
||||||
|
// rather than silently dropped.
|
||||||
|
const existing = accumulator.get(publicKey);
|
||||||
|
if (existing) {
|
||||||
|
existing.bytesIn += bytesIn;
|
||||||
|
existing.bytesOut += bytesOut;
|
||||||
|
} else {
|
||||||
|
accumulator.set(publicKey, {
|
||||||
|
bytesIn,
|
||||||
|
bytesOut,
|
||||||
|
exitNodeId,
|
||||||
|
calcUsage
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process billing usage updates outside the site-update loop to keep
|
||||||
|
// lock scope small and concerns separated.
|
||||||
|
if (orgUsageMap.size > 0) {
|
||||||
|
// Sort org IDs for consistent lock ordering.
|
||||||
|
const sortedOrgIds = [...orgUsageMap.keys()].sort();
|
||||||
|
|
||||||
|
for (const orgId of sortedOrgIds) {
|
||||||
|
try {
|
||||||
|
const totalBandwidth = orgUsageMap.get(orgId)!;
|
||||||
|
const bandwidthUsage = await usageService.add(
|
||||||
|
orgId,
|
||||||
|
FeatureId.EGRESS_DATA_MB,
|
||||||
|
totalBandwidth
|
||||||
|
);
|
||||||
|
if (bandwidthUsage) {
|
||||||
|
// Fire-and-forget — don't block the flush on limit checking.
|
||||||
|
usageService
|
||||||
|
.checkLimitSet(
|
||||||
|
orgId,
|
||||||
|
FeatureId.EGRESS_DATA_MB,
|
||||||
|
bandwidthUsage
|
||||||
|
)
|
||||||
|
.catch((error: any) => {
|
||||||
|
logger.error(
|
||||||
|
`Error checking bandwidth limits for org ${orgId}:`,
|
||||||
|
error
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
`Error processing usage for org ${orgId}:`,
|
||||||
|
error
|
||||||
|
);
|
||||||
|
// Continue with other orgs.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Periodic flush timer
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
const flushTimer = setInterval(async () => {
|
||||||
|
try {
|
||||||
|
await flushSiteBandwidthToDb();
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
"Unexpected error during periodic site bandwidth flush:",
|
||||||
|
error
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}, FLUSH_INTERVAL_MS);
|
||||||
|
|
||||||
|
// Allow the process to exit normally even while the timer is pending.
|
||||||
|
// The graceful-shutdown path (see server/cleanup.ts) will call
|
||||||
|
// flushSiteBandwidthToDb() explicitly before process.exit(), so no data
|
||||||
|
// is lost.
|
||||||
|
flushTimer.unref();
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Public API
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Accumulate bandwidth data reported by a gerbil or remote exit node.
|
||||||
|
*
|
||||||
|
* Only peers that actually transferred data (bytesIn > 0) are added to the
|
||||||
|
* accumulator; peers with no activity are silently ignored, which means the
|
||||||
|
* flush will only write rows that have genuinely changed.
|
||||||
|
*
|
||||||
|
* The function is intentionally synchronous in its fast path so that the
|
||||||
|
* HTTP handler can respond immediately without waiting for any I/O.
|
||||||
|
*/
|
||||||
|
export async function updateSiteBandwidth(
|
||||||
|
bandwidthData: PeerBandwidth[],
|
||||||
|
calcUsageAndLimits: boolean,
|
||||||
|
exitNodeId?: number
|
||||||
|
): Promise<void> {
|
||||||
|
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
|
||||||
|
// Skip peers that haven't transferred any data — writing zeros to the
|
||||||
|
// database would be a no-op anyway.
|
||||||
|
if (bytesIn <= 0 && bytesOut <= 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const existing = accumulator.get(publicKey);
|
||||||
|
if (existing) {
|
||||||
|
existing.bytesIn += bytesIn;
|
||||||
|
existing.bytesOut += bytesOut;
|
||||||
|
// Retain the most-recent exitNodeId for this peer.
|
||||||
|
if (exitNodeId !== undefined) {
|
||||||
|
existing.exitNodeId = exitNodeId;
|
||||||
|
}
|
||||||
|
// Once calcUsage has been requested for a peer, keep it set for
|
||||||
|
// the lifetime of this flush window.
|
||||||
|
if (calcUsageAndLimits) {
|
||||||
|
existing.calcUsage = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
accumulator.set(publicKey, {
|
||||||
|
bytesIn,
|
||||||
|
bytesOut,
|
||||||
|
exitNodeId,
|
||||||
|
calcUsage: calcUsageAndLimits
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// HTTP handler
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
export const receiveBandwidth = async (
|
export const receiveBandwidth = async (
|
||||||
req: Request,
|
req: Request,
|
||||||
res: Response,
|
res: Response,
|
||||||
@@ -75,7 +301,9 @@ export const receiveBandwidth = async (
|
|||||||
throw new Error("Invalid bandwidth data");
|
throw new Error("Invalid bandwidth data");
|
||||||
}
|
}
|
||||||
|
|
||||||
await updateSiteBandwidth(bandwidthData, build == "saas"); // we are checking the usage on saas only
|
// Accumulate in memory; the periodic timer (and the shutdown hook)
|
||||||
|
// will write to the database.
|
||||||
|
await updateSiteBandwidth(bandwidthData, build == "saas");
|
||||||
|
|
||||||
return response(res, {
|
return response(res, {
|
||||||
data: {},
|
data: {},
|
||||||
@@ -94,201 +322,3 @@ export const receiveBandwidth = async (
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
export async function updateSiteBandwidth(
|
|
||||||
bandwidthData: PeerBandwidth[],
|
|
||||||
calcUsageAndLimits: boolean,
|
|
||||||
exitNodeId?: number
|
|
||||||
) {
|
|
||||||
const currentTime = new Date();
|
|
||||||
const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago
|
|
||||||
|
|
||||||
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
|
|
||||||
// This is critical for preventing deadlocks when multiple instances update the same sites
|
|
||||||
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
|
|
||||||
a.publicKey.localeCompare(b.publicKey)
|
|
||||||
);
|
|
||||||
|
|
||||||
// First, handle sites that are actively reporting bandwidth
|
|
||||||
const activePeers = sortedBandwidthData.filter((peer) => peer.bytesIn > 0);
|
|
||||||
|
|
||||||
// Aggregate usage data by organization (collected outside transaction)
|
|
||||||
const orgUsageMap = new Map<string, number>();
|
|
||||||
|
|
||||||
if (activePeers.length > 0) {
|
|
||||||
// Remove any active peers from offline tracking since they're sending data
|
|
||||||
activePeers.forEach((peer) => offlineSites.delete(peer.publicKey));
|
|
||||||
|
|
||||||
// Update each active site individually with retry logic
|
|
||||||
// This reduces transaction scope and allows retries per-site
|
|
||||||
for (const peer of activePeers) {
|
|
||||||
try {
|
|
||||||
const updatedSite = await withDeadlockRetry(async () => {
|
|
||||||
const [result] = await db
|
|
||||||
.update(sites)
|
|
||||||
.set({
|
|
||||||
megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
|
|
||||||
megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
|
|
||||||
lastBandwidthUpdate: currentTime.toISOString(),
|
|
||||||
online: true
|
|
||||||
})
|
|
||||||
.where(eq(sites.pubKey, peer.publicKey))
|
|
||||||
.returning({
|
|
||||||
online: sites.online,
|
|
||||||
orgId: sites.orgId,
|
|
||||||
siteId: sites.siteId,
|
|
||||||
lastBandwidthUpdate: sites.lastBandwidthUpdate
|
|
||||||
});
|
|
||||||
return result;
|
|
||||||
}, `update active site ${peer.publicKey}`);
|
|
||||||
|
|
||||||
if (updatedSite) {
|
|
||||||
if (exitNodeId) {
|
|
||||||
const notAllowed = await checkExitNodeOrg(
|
|
||||||
exitNodeId,
|
|
||||||
updatedSite.orgId
|
|
||||||
);
|
|
||||||
if (notAllowed) {
|
|
||||||
logger.warn(
|
|
||||||
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
|
|
||||||
);
|
|
||||||
// Skip this site but continue processing others
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Aggregate bandwidth usage for the org
|
|
||||||
const totalBandwidth = peer.bytesIn + peer.bytesOut;
|
|
||||||
const currentOrgUsage =
|
|
||||||
orgUsageMap.get(updatedSite.orgId) || 0;
|
|
||||||
orgUsageMap.set(
|
|
||||||
updatedSite.orgId,
|
|
||||||
currentOrgUsage + totalBandwidth
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error(
|
|
||||||
`Failed to update bandwidth for site ${peer.publicKey}:`,
|
|
||||||
error
|
|
||||||
);
|
|
||||||
// Continue with other sites
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process usage updates outside of site update transactions
|
|
||||||
// This separates the concerns and reduces lock contention
|
|
||||||
if (calcUsageAndLimits && orgUsageMap.size > 0) {
|
|
||||||
// Sort org IDs to ensure consistent lock ordering
|
|
||||||
const allOrgIds = [...new Set([...orgUsageMap.keys()])].sort();
|
|
||||||
|
|
||||||
for (const orgId of allOrgIds) {
|
|
||||||
try {
|
|
||||||
// Process bandwidth usage for this org
|
|
||||||
const totalBandwidth = orgUsageMap.get(orgId);
|
|
||||||
if (totalBandwidth) {
|
|
||||||
const bandwidthUsage = await usageService.add(
|
|
||||||
orgId,
|
|
||||||
FeatureId.EGRESS_DATA_MB,
|
|
||||||
totalBandwidth
|
|
||||||
);
|
|
||||||
if (bandwidthUsage) {
|
|
||||||
// Fire and forget - don't block on limit checking
|
|
||||||
usageService
|
|
||||||
.checkLimitSet(
|
|
||||||
orgId,
|
|
||||||
FeatureId.EGRESS_DATA_MB,
|
|
||||||
bandwidthUsage
|
|
||||||
)
|
|
||||||
.catch((error: any) => {
|
|
||||||
logger.error(
|
|
||||||
`Error checking bandwidth limits for org ${orgId}:`,
|
|
||||||
error
|
|
||||||
);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error(`Error processing usage for org ${orgId}:`, error);
|
|
||||||
// Continue with other orgs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle sites that reported zero bandwidth but need online status updated
|
|
||||||
const zeroBandwidthPeers = sortedBandwidthData.filter(
|
|
||||||
(peer) => peer.bytesIn === 0 && !offlineSites.has(peer.publicKey)
|
|
||||||
);
|
|
||||||
|
|
||||||
if (zeroBandwidthPeers.length > 0) {
|
|
||||||
// Fetch all zero bandwidth sites in one query
|
|
||||||
const zeroBandwidthSites = await db
|
|
||||||
.select()
|
|
||||||
.from(sites)
|
|
||||||
.where(
|
|
||||||
inArray(
|
|
||||||
sites.pubKey,
|
|
||||||
zeroBandwidthPeers.map((p) => p.publicKey)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
// Sort by siteId to ensure consistent lock ordering
|
|
||||||
const sortedZeroBandwidthSites = zeroBandwidthSites.sort(
|
|
||||||
(a, b) => a.siteId - b.siteId
|
|
||||||
);
|
|
||||||
|
|
||||||
for (const site of sortedZeroBandwidthSites) {
|
|
||||||
let newOnlineStatus = site.online;
|
|
||||||
|
|
||||||
// Check if site should go offline based on last bandwidth update WITH DATA
|
|
||||||
if (site.lastBandwidthUpdate) {
|
|
||||||
const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
|
|
||||||
if (lastUpdateWithData < oneMinuteAgo) {
|
|
||||||
newOnlineStatus = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No previous data update recorded, set to offline
|
|
||||||
newOnlineStatus = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only update online status if it changed
|
|
||||||
if (site.online !== newOnlineStatus) {
|
|
||||||
try {
|
|
||||||
const updatedSite = await withDeadlockRetry(async () => {
|
|
||||||
const [result] = await db
|
|
||||||
.update(sites)
|
|
||||||
.set({
|
|
||||||
online: newOnlineStatus
|
|
||||||
})
|
|
||||||
.where(eq(sites.siteId, site.siteId))
|
|
||||||
.returning();
|
|
||||||
return result;
|
|
||||||
}, `update offline status for site ${site.siteId}`);
|
|
||||||
|
|
||||||
if (updatedSite && exitNodeId) {
|
|
||||||
const notAllowed = await checkExitNodeOrg(
|
|
||||||
exitNodeId,
|
|
||||||
updatedSite.orgId
|
|
||||||
);
|
|
||||||
if (notAllowed) {
|
|
||||||
logger.warn(
|
|
||||||
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If site went offline, add it to our tracking set
|
|
||||||
if (!newOnlineStatus && site.pubKey) {
|
|
||||||
offlineSites.add(site.pubKey);
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error(
|
|
||||||
`Failed to update offline status for site ${site.siteId}:`,
|
|
||||||
error
|
|
||||||
);
|
|
||||||
// Continue with other sites
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,4 +1,15 @@
|
|||||||
import { clients, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, ExitNode, resources, Site, siteResources, targetHealthCheck, targets } from "@server/db";
|
import {
|
||||||
|
clients,
|
||||||
|
clientSiteResourcesAssociationsCache,
|
||||||
|
clientSitesAssociationsCache,
|
||||||
|
db,
|
||||||
|
ExitNode,
|
||||||
|
resources,
|
||||||
|
Site,
|
||||||
|
siteResources,
|
||||||
|
targetHealthCheck,
|
||||||
|
targets
|
||||||
|
} from "@server/db";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
|
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
|
||||||
import { eq, and } from "drizzle-orm";
|
import { eq, and } from "drizzle-orm";
|
||||||
@@ -69,6 +80,7 @@ export async function buildClientConfigurationForNewtClient(
|
|||||||
// )
|
// )
|
||||||
// );
|
// );
|
||||||
|
|
||||||
|
if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
|
||||||
// update the peer info on the olm
|
// update the peer info on the olm
|
||||||
// if the peer has not been added yet this will be a no-op
|
// if the peer has not been added yet this will be a no-op
|
||||||
await updatePeer(client.clients.clientId, {
|
await updatePeer(client.clients.clientId, {
|
||||||
@@ -103,6 +115,7 @@ export async function buildClientConfigurationForNewtClient(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
publicKey: client.clients.pubKey!,
|
publicKey: client.clients.pubKey!,
|
||||||
@@ -230,9 +243,9 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
|
|||||||
!target.hcInterval ||
|
!target.hcInterval ||
|
||||||
!target.hcMethod
|
!target.hcMethod
|
||||||
) {
|
) {
|
||||||
logger.debug(
|
// logger.debug(
|
||||||
`Skipping adding target health check ${target.targetId} due to missing health check fields`
|
// `Skipping adding target health check ${target.targetId} due to missing health check fields`
|
||||||
);
|
// );
|
||||||
return null; // Skip targets with missing health check fields
|
return null; // Skip targets with missing health check fields
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
|
|||||||
import { eq } from "drizzle-orm";
|
import { eq } from "drizzle-orm";
|
||||||
import { sendToExitNode } from "#dynamic/lib/exitNodes";
|
import { sendToExitNode } from "#dynamic/lib/exitNodes";
|
||||||
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
|
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
|
|
||||||
const inputSchema = z.object({
|
const inputSchema = z.object({
|
||||||
publicKey: z.string(),
|
publicKey: z.string(),
|
||||||
@@ -135,6 +136,9 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
|||||||
targets
|
targets
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
options: {
|
||||||
|
compress: canCompress(newt.version, "newt")
|
||||||
|
},
|
||||||
broadcast: false,
|
broadcast: false,
|
||||||
excludeSender: false
|
excludeSender: false
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,105 +1,107 @@
|
|||||||
import { db, sites } from "@server/db";
|
import { db, newts, sites } from "@server/db";
|
||||||
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
|
import { hasActiveConnections, getClientConfigVersion } from "#dynamic/routers/ws";
|
||||||
import { MessageHandler } from "@server/routers/ws";
|
import { MessageHandler } from "@server/routers/ws";
|
||||||
import { clients, Newt } from "@server/db";
|
import { Newt } from "@server/db";
|
||||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
|
||||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
|
||||||
import { sendTerminateClient } from "../client/terminate";
|
|
||||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
|
||||||
import { sha256 } from "@oslojs/crypto/sha2";
|
|
||||||
import { sendNewtSyncMessage } from "./sync";
|
import { sendNewtSyncMessage } from "./sync";
|
||||||
|
|
||||||
// Track if the offline checker interval is running
|
// Track if the offline checker interval is running
|
||||||
// let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||||
// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
|
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
|
||||||
// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
|
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Starts the background interval that checks for clients that haven't pinged recently
|
* Starts the background interval that checks for newt sites that haven't
|
||||||
* and marks them as offline
|
* pinged recently and marks them as offline. For backward compatibility,
|
||||||
|
* a site is only marked offline when there is no active WebSocket connection
|
||||||
|
* either — so older newt versions that don't send pings but remain connected
|
||||||
|
* continue to be treated as online.
|
||||||
*/
|
*/
|
||||||
// export const startNewtOfflineChecker = (): void => {
|
export const startNewtOfflineChecker = (): void => {
|
||||||
// if (offlineCheckerInterval) {
|
if (offlineCheckerInterval) {
|
||||||
// return; // Already running
|
return; // Already running
|
||||||
// }
|
}
|
||||||
|
|
||||||
// offlineCheckerInterval = setInterval(async () => {
|
offlineCheckerInterval = setInterval(async () => {
|
||||||
// try {
|
try {
|
||||||
// const twoMinutesAgo = Math.floor(
|
const twoMinutesAgo = Math.floor(
|
||||||
// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
||||||
// );
|
);
|
||||||
|
|
||||||
// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
|
// Find all online newt-type sites that haven't pinged recently
|
||||||
|
// (or have never pinged at all). Join newts to obtain the newtId
|
||||||
|
// needed for the WebSocket connection check.
|
||||||
|
const staleSites = await db
|
||||||
|
.select({
|
||||||
|
siteId: sites.siteId,
|
||||||
|
newtId: newts.newtId,
|
||||||
|
lastPing: sites.lastPing
|
||||||
|
})
|
||||||
|
.from(sites)
|
||||||
|
.innerJoin(newts, eq(newts.siteId, sites.siteId))
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(sites.online, true),
|
||||||
|
eq(sites.type, "newt"),
|
||||||
|
or(
|
||||||
|
lt(sites.lastPing, twoMinutesAgo),
|
||||||
|
isNull(sites.lastPing)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
// // Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
for (const staleSite of staleSites) {
|
||||||
// const offlineClients = await db
|
// Backward-compatibility check: if the newt still has an
|
||||||
// .update(clients)
|
// active WebSocket connection (older clients that don't send
|
||||||
// .set({ online: false })
|
// pings), keep the site online.
|
||||||
// .where(
|
const isConnected = await hasActiveConnections(staleSite.newtId);
|
||||||
// and(
|
if (isConnected) {
|
||||||
// eq(clients.online, true),
|
logger.debug(
|
||||||
// or(
|
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
|
||||||
// lt(clients.lastPing, twoMinutesAgo),
|
);
|
||||||
// isNull(clients.lastPing)
|
continue;
|
||||||
// )
|
}
|
||||||
// )
|
|
||||||
// )
|
|
||||||
// .returning();
|
|
||||||
|
|
||||||
// for (const offlineClient of offlineClients) {
|
logger.info(
|
||||||
// logger.info(
|
`Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
|
||||||
// `Kicking offline newt client ${offlineClient.clientId} due to inactivity`
|
);
|
||||||
// );
|
|
||||||
|
|
||||||
// if (!offlineClient.newtId) {
|
await db
|
||||||
// logger.warn(
|
.update(sites)
|
||||||
// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect`
|
.set({ online: false })
|
||||||
// );
|
.where(eq(sites.siteId, staleSite.siteId));
|
||||||
// continue;
|
}
|
||||||
// }
|
} catch (error) {
|
||||||
|
logger.error("Error in newt offline checker interval", { error });
|
||||||
|
}
|
||||||
|
}, OFFLINE_CHECK_INTERVAL);
|
||||||
|
|
||||||
// // Send a disconnect message to the client if connected
|
logger.debug("Started newt offline checker interval");
|
||||||
// try {
|
};
|
||||||
// await sendTerminateClient(
|
|
||||||
// offlineClient.clientId,
|
|
||||||
// offlineClient.newtId
|
|
||||||
// ); // terminate first
|
|
||||||
// // wait a moment to ensure the message is sent
|
|
||||||
// await new Promise((resolve) => setTimeout(resolve, 1000));
|
|
||||||
// await disconnectClient(offlineClient.newtId);
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error(
|
|
||||||
// `Error sending disconnect to offline newt ${offlineClient.clientId}`,
|
|
||||||
// { error }
|
|
||||||
// );
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error("Error in offline checker interval", { error });
|
|
||||||
// }
|
|
||||||
// }, OFFLINE_CHECK_INTERVAL);
|
|
||||||
|
|
||||||
// logger.debug("Started offline checker interval");
|
|
||||||
// };
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Stops the background interval that checks for offline clients
|
* Stops the background interval that checks for offline newt sites.
|
||||||
*/
|
*/
|
||||||
// export const stopNewtOfflineChecker = (): void => {
|
export const stopNewtOfflineChecker = (): void => {
|
||||||
// if (offlineCheckerInterval) {
|
if (offlineCheckerInterval) {
|
||||||
// clearInterval(offlineCheckerInterval);
|
clearInterval(offlineCheckerInterval);
|
||||||
// offlineCheckerInterval = null;
|
offlineCheckerInterval = null;
|
||||||
// logger.info("Stopped offline checker interval");
|
logger.info("Stopped newt offline checker interval");
|
||||||
// }
|
}
|
||||||
// };
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handles ping messages from clients and responds with pong
|
* Handles ping messages from newt clients.
|
||||||
|
*
|
||||||
|
* On each ping:
|
||||||
|
* - Marks the associated site as online.
|
||||||
|
* - Records the current timestamp as the newt's last-ping time.
|
||||||
|
* - Triggers a config sync if the newt is running an outdated config version.
|
||||||
|
* - Responds with a pong message.
|
||||||
*/
|
*/
|
||||||
export const handleNewtPingMessage: MessageHandler = async (context) => {
|
export const handleNewtPingMessage: MessageHandler = async (context) => {
|
||||||
const { message, client: c, sendToClient } = context;
|
const { message, client: c } = context;
|
||||||
const newt = c as Newt;
|
const newt = c as Newt;
|
||||||
|
|
||||||
if (!newt) {
|
if (!newt) {
|
||||||
@@ -112,15 +114,31 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// get the version
|
try {
|
||||||
|
// Mark the site as online and record the ping timestamp.
|
||||||
|
await db
|
||||||
|
.update(sites)
|
||||||
|
.set({
|
||||||
|
online: true,
|
||||||
|
lastPing: Math.floor(Date.now() / 1000)
|
||||||
|
})
|
||||||
|
.where(eq(sites.siteId, newt.siteId));
|
||||||
|
} catch (error) {
|
||||||
|
logger.error("Error updating online state on newt ping", { error });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check config version and sync if stale.
|
||||||
const configVersion = await getClientConfigVersion(newt.newtId);
|
const configVersion = await getClientConfigVersion(newt.newtId);
|
||||||
|
|
||||||
if (message.configVersion && configVersion != null && configVersion != message.configVersion) {
|
if (
|
||||||
|
message.configVersion != null &&
|
||||||
|
configVersion != null &&
|
||||||
|
configVersion !== message.configVersion
|
||||||
|
) {
|
||||||
logger.warn(
|
logger.warn(
|
||||||
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
|
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
|
||||||
);
|
);
|
||||||
|
|
||||||
// get the site
|
|
||||||
const [site] = await db
|
const [site] = await db
|
||||||
.select()
|
.select()
|
||||||
.from(sites)
|
.from(sites)
|
||||||
@@ -137,19 +155,6 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
|
|||||||
await sendNewtSyncMessage(newt, site);
|
await sendNewtSyncMessage(newt, site);
|
||||||
}
|
}
|
||||||
|
|
||||||
// try {
|
|
||||||
// // Update the client's last ping timestamp
|
|
||||||
// await db
|
|
||||||
// .update(clients)
|
|
||||||
// .set({
|
|
||||||
// lastPing: Math.floor(Date.now() / 1000),
|
|
||||||
// online: true
|
|
||||||
// })
|
|
||||||
// .where(eq(clients.clientId, newt.clientId));
|
|
||||||
// } catch (error) {
|
|
||||||
// logger.error("Error handling ping message", { error });
|
|
||||||
// }
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
message: {
|
message: {
|
||||||
type: "pong",
|
type: "pong",
|
||||||
|
|||||||
@@ -5,9 +5,7 @@ import { eq } from "drizzle-orm";
|
|||||||
import { addPeer, deletePeer } from "../gerbil/peers";
|
import { addPeer, deletePeer } from "../gerbil/peers";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import config from "@server/lib/config";
|
import config from "@server/lib/config";
|
||||||
import {
|
import { findNextAvailableCidr } from "@server/lib/ip";
|
||||||
findNextAvailableCidr,
|
|
||||||
} from "@server/lib/ip";
|
|
||||||
import {
|
import {
|
||||||
selectBestExitNode,
|
selectBestExitNode,
|
||||||
verifyExitNodeOrgAccess
|
verifyExitNodeOrgAccess
|
||||||
@@ -15,6 +13,7 @@ import {
|
|||||||
import { fetchContainers } from "./dockerSocket";
|
import { fetchContainers } from "./dockerSocket";
|
||||||
import { lockManager } from "#dynamic/lib/lock";
|
import { lockManager } from "#dynamic/lib/lock";
|
||||||
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
|
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
|
|
||||||
export type ExitNodePingResult = {
|
export type ExitNodePingResult = {
|
||||||
exitNodeId: number;
|
exitNodeId: number;
|
||||||
@@ -215,6 +214,9 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
|||||||
healthCheckTargets: validHealthCheckTargets
|
healthCheckTargets: validHealthCheckTargets
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
options: {
|
||||||
|
compress: canCompress(newt.version, "newt")
|
||||||
|
},
|
||||||
broadcast: false, // Send to all clients
|
broadcast: false, // Send to all clients
|
||||||
excludeSender: false // Include sender in broadcast
|
excludeSender: false // Include sender in broadcast
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -10,10 +10,21 @@ interface PeerBandwidth {
|
|||||||
bytesOut: number;
|
bytesOut: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface BandwidthAccumulator {
|
||||||
|
bytesIn: number;
|
||||||
|
bytesOut: number;
|
||||||
|
}
|
||||||
|
|
||||||
// Retry configuration for deadlock handling
|
// Retry configuration for deadlock handling
|
||||||
const MAX_RETRIES = 3;
|
const MAX_RETRIES = 3;
|
||||||
const BASE_DELAY_MS = 50;
|
const BASE_DELAY_MS = 50;
|
||||||
|
|
||||||
|
// How often to flush accumulated bandwidth data to the database
|
||||||
|
const FLUSH_INTERVAL_MS = 120_000; // 120 seconds
|
||||||
|
|
||||||
|
// In-memory accumulator: publicKey -> { bytesIn, bytesOut }
|
||||||
|
let accumulator = new Map<string, BandwidthAccumulator>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check if an error is a deadlock error
|
* Check if an error is a deadlock error
|
||||||
*/
|
*/
|
||||||
@@ -53,6 +64,90 @@ async function withDeadlockRetry<T>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Flush all accumulated bandwidth data to the database.
|
||||||
|
*
|
||||||
|
* Swaps out the accumulator before writing so that any bandwidth messages
|
||||||
|
* received during the flush are captured in the new accumulator rather than
|
||||||
|
* being lost or causing contention. Entries that fail to write are re-queued
|
||||||
|
* back into the accumulator so they will be retried on the next flush.
|
||||||
|
*
|
||||||
|
* This function is exported so that the application's graceful-shutdown
|
||||||
|
* cleanup handler can call it before the process exits.
|
||||||
|
*/
|
||||||
|
export async function flushBandwidthToDb(): Promise<void> {
|
||||||
|
if (accumulator.size === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Atomically swap out the accumulator so new data keeps flowing in
|
||||||
|
// while we write the snapshot to the database.
|
||||||
|
const snapshot = accumulator;
|
||||||
|
accumulator = new Map<string, BandwidthAccumulator>();
|
||||||
|
|
||||||
|
const currentTime = new Date().toISOString();
|
||||||
|
|
||||||
|
// Sort by publicKey for consistent lock ordering across concurrent
|
||||||
|
// writers — this is the same deadlock-prevention strategy used in the
|
||||||
|
// original per-message implementation.
|
||||||
|
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
|
||||||
|
a.localeCompare(b)
|
||||||
|
);
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
`Flushing accumulated bandwidth data for ${sortedEntries.length} client(s) to the database`
|
||||||
|
);
|
||||||
|
|
||||||
|
for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) {
|
||||||
|
try {
|
||||||
|
await withDeadlockRetry(async () => {
|
||||||
|
// Use atomic SQL increment to avoid the SELECT-then-UPDATE
|
||||||
|
// anti-pattern and the races it would introduce.
|
||||||
|
await db
|
||||||
|
.update(clients)
|
||||||
|
.set({
|
||||||
|
// Note: bytesIn from peer goes to megabytesOut (data
|
||||||
|
// sent to client) and bytesOut from peer goes to
|
||||||
|
// megabytesIn (data received from client).
|
||||||
|
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
|
||||||
|
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
|
||||||
|
lastBandwidthUpdate: currentTime
|
||||||
|
})
|
||||||
|
.where(eq(clients.pubKey, publicKey));
|
||||||
|
}, `flush bandwidth for client ${publicKey}`);
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
`Failed to flush bandwidth for client ${publicKey}:`,
|
||||||
|
error
|
||||||
|
);
|
||||||
|
|
||||||
|
// Re-queue the failed entry so it is retried on the next flush
|
||||||
|
// rather than silently dropped.
|
||||||
|
const existing = accumulator.get(publicKey);
|
||||||
|
if (existing) {
|
||||||
|
existing.bytesIn += bytesIn;
|
||||||
|
existing.bytesOut += bytesOut;
|
||||||
|
} else {
|
||||||
|
accumulator.set(publicKey, { bytesIn, bytesOut });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const flushTimer = setInterval(async () => {
|
||||||
|
try {
|
||||||
|
await flushBandwidthToDb();
|
||||||
|
} catch (error) {
|
||||||
|
logger.error("Unexpected error during periodic bandwidth flush:", error);
|
||||||
|
}
|
||||||
|
}, FLUSH_INTERVAL_MS);
|
||||||
|
|
||||||
|
// Calling unref() means this timer will not keep the Node.js event loop alive
|
||||||
|
// on its own — the process can still exit normally when there is no other work
|
||||||
|
// left. The graceful-shutdown path (see server/cleanup.ts) will call
|
||||||
|
// flushBandwidthToDb() explicitly before process.exit(), so no data is lost.
|
||||||
|
flushTimer.unref();
|
||||||
|
|
||||||
export const handleReceiveBandwidthMessage: MessageHandler = async (
|
export const handleReceiveBandwidthMessage: MessageHandler = async (
|
||||||
context
|
context
|
||||||
) => {
|
) => {
|
||||||
@@ -69,40 +164,21 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
|
|||||||
throw new Error("Invalid bandwidth data");
|
throw new Error("Invalid bandwidth data");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
|
// Accumulate the incoming data in memory; the periodic timer (and the
|
||||||
// This is critical for preventing deadlocks when multiple instances update the same clients
|
// shutdown hook) will take care of writing it to the database.
|
||||||
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
|
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
|
||||||
a.publicKey.localeCompare(b.publicKey)
|
// Skip peers that haven't transferred any data — writing zeros to the
|
||||||
);
|
// database would be a no-op anyway.
|
||||||
|
if (bytesIn <= 0 && bytesOut <= 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const currentTime = new Date().toISOString();
|
const existing = accumulator.get(publicKey);
|
||||||
|
if (existing) {
|
||||||
// Update each client individually with retry logic
|
existing.bytesIn += bytesIn;
|
||||||
// This reduces transaction scope and allows retries per-client
|
existing.bytesOut += bytesOut;
|
||||||
for (const peer of sortedBandwidthData) {
|
} else {
|
||||||
const { publicKey, bytesIn, bytesOut } = peer;
|
accumulator.set(publicKey, { bytesIn, bytesOut });
|
||||||
|
|
||||||
try {
|
|
||||||
await withDeadlockRetry(async () => {
|
|
||||||
// Use atomic SQL increment to avoid SELECT then UPDATE pattern
|
|
||||||
// This eliminates the need to read the current value first
|
|
||||||
await db
|
|
||||||
.update(clients)
|
|
||||||
.set({
|
|
||||||
// Note: bytesIn from peer goes to megabytesOut (data sent to client)
|
|
||||||
// and bytesOut from peer goes to megabytesIn (data received from client)
|
|
||||||
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
|
|
||||||
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
|
|
||||||
lastBandwidthUpdate: currentTime
|
|
||||||
})
|
|
||||||
.where(eq(clients.pubKey, publicKey));
|
|
||||||
}, `update client bandwidth ${publicKey}`);
|
|
||||||
} catch (error) {
|
|
||||||
logger.error(
|
|
||||||
`Failed to update bandwidth for client ${publicKey}:`,
|
|
||||||
error
|
|
||||||
);
|
|
||||||
// Continue with other clients even if one fails
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import {
|
|||||||
buildClientConfigurationForNewtClient,
|
buildClientConfigurationForNewtClient,
|
||||||
buildTargetConfigurationForNewtClient
|
buildTargetConfigurationForNewtClient
|
||||||
} from "./buildConfiguration";
|
} from "./buildConfiguration";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
|
|
||||||
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
|
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
|
||||||
const { tcpTargets, udpTargets, validHealthCheckTargets } =
|
const { tcpTargets, udpTargets, validHealthCheckTargets } =
|
||||||
@@ -24,7 +25,9 @@ export async function sendNewtSyncMessage(newt: Newt, site: Site) {
|
|||||||
exitNode
|
exitNode
|
||||||
);
|
);
|
||||||
|
|
||||||
await sendToClient(newt.newtId, {
|
await sendToClient(
|
||||||
|
newt.newtId,
|
||||||
|
{
|
||||||
type: "newt/sync",
|
type: "newt/sync",
|
||||||
data: {
|
data: {
|
||||||
proxyTargets: {
|
proxyTargets: {
|
||||||
@@ -35,7 +38,11 @@ export async function sendNewtSyncMessage(newt: Newt, site: Site) {
|
|||||||
peers: peers,
|
peers: peers,
|
||||||
clientTargets: targets
|
clientTargets: targets
|
||||||
}
|
}
|
||||||
}).catch((error) => {
|
},
|
||||||
|
{
|
||||||
|
compress: canCompress(newt.version, "newt")
|
||||||
|
}
|
||||||
|
).catch((error) => {
|
||||||
logger.warn(`Error sending newt sync message:`, error);
|
logger.warn(`Error sending newt sync message:`, error);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,14 @@ import { Target, TargetHealthCheck, db, targetHealthCheck } from "@server/db";
|
|||||||
import { sendToClient } from "#dynamic/routers/ws";
|
import { sendToClient } from "#dynamic/routers/ws";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { eq, inArray } from "drizzle-orm";
|
import { eq, inArray } from "drizzle-orm";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
|
|
||||||
export async function addTargets(
|
export async function addTargets(
|
||||||
newtId: string,
|
newtId: string,
|
||||||
targets: Target[],
|
targets: Target[],
|
||||||
healthCheckData: TargetHealthCheck[],
|
healthCheckData: TargetHealthCheck[],
|
||||||
protocol: string,
|
protocol: string,
|
||||||
port: number | null = null
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
//create a list of udp and tcp targets
|
//create a list of udp and tcp targets
|
||||||
const payloadTargets = targets.map((target) => {
|
const payloadTargets = targets.map((target) => {
|
||||||
@@ -22,7 +23,7 @@ export async function addTargets(
|
|||||||
data: {
|
data: {
|
||||||
targets: payloadTargets
|
targets: payloadTargets
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true });
|
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
|
||||||
|
|
||||||
// Create a map for quick lookup
|
// Create a map for quick lookup
|
||||||
const healthCheckMap = new Map<number, TargetHealthCheck>();
|
const healthCheckMap = new Map<number, TargetHealthCheck>();
|
||||||
@@ -103,14 +104,14 @@ export async function addTargets(
|
|||||||
data: {
|
data: {
|
||||||
targets: validHealthCheckTargets
|
targets: validHealthCheckTargets
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true });
|
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function removeTargets(
|
export async function removeTargets(
|
||||||
newtId: string,
|
newtId: string,
|
||||||
targets: Target[],
|
targets: Target[],
|
||||||
protocol: string,
|
protocol: string,
|
||||||
port: number | null = null
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
//create a list of udp and tcp targets
|
//create a list of udp and tcp targets
|
||||||
const payloadTargets = targets.map((target) => {
|
const payloadTargets = targets.map((target) => {
|
||||||
@@ -135,5 +136,5 @@ export async function removeTargets(
|
|||||||
data: {
|
data: {
|
||||||
ids: healthCheckTargets
|
ids: healthCheckTargets
|
||||||
}
|
}
|
||||||
}, { incrementConfigVersion: true });
|
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,17 @@
|
|||||||
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db";
|
import {
|
||||||
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip";
|
Client,
|
||||||
|
clientSiteResourcesAssociationsCache,
|
||||||
|
clientSitesAssociationsCache,
|
||||||
|
db,
|
||||||
|
exitNodes,
|
||||||
|
siteResources,
|
||||||
|
sites
|
||||||
|
} from "@server/db";
|
||||||
|
import {
|
||||||
|
Alias,
|
||||||
|
generateAliasConfig,
|
||||||
|
generateRemoteSubnets
|
||||||
|
} from "@server/lib/ip";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { and, eq } from "drizzle-orm";
|
import { and, eq } from "drizzle-orm";
|
||||||
import { addPeer, deletePeer } from "../newt/peers";
|
import { addPeer, deletePeer } from "../newt/peers";
|
||||||
@@ -8,9 +20,19 @@ import config from "@server/lib/config";
|
|||||||
export async function buildSiteConfigurationForOlmClient(
|
export async function buildSiteConfigurationForOlmClient(
|
||||||
client: Client,
|
client: Client,
|
||||||
publicKey: string | null,
|
publicKey: string | null,
|
||||||
relay: boolean
|
relay: boolean,
|
||||||
|
jitMode: boolean = false
|
||||||
) {
|
) {
|
||||||
const siteConfigurations = [];
|
const siteConfigurations: {
|
||||||
|
siteId: number;
|
||||||
|
name?: string
|
||||||
|
endpoint?: string
|
||||||
|
publicKey?: string
|
||||||
|
serverIP?: string | null
|
||||||
|
serverPort?: number | null
|
||||||
|
remoteSubnets?: string[];
|
||||||
|
aliases: Alias[];
|
||||||
|
}[] = [];
|
||||||
|
|
||||||
// Get all sites data
|
// Get all sites data
|
||||||
const sitesData = await db
|
const sitesData = await db
|
||||||
@@ -27,6 +49,40 @@ export async function buildSiteConfigurationForOlmClient(
|
|||||||
sites: site,
|
sites: site,
|
||||||
clientSitesAssociationsCache: association
|
clientSitesAssociationsCache: association
|
||||||
} of sitesData) {
|
} of sitesData) {
|
||||||
|
const allSiteResources = await db // only get the site resources that this client has access to
|
||||||
|
.select()
|
||||||
|
.from(siteResources)
|
||||||
|
.innerJoin(
|
||||||
|
clientSiteResourcesAssociationsCache,
|
||||||
|
eq(
|
||||||
|
siteResources.siteResourceId,
|
||||||
|
clientSiteResourcesAssociationsCache.siteResourceId
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(siteResources.siteId, site.siteId),
|
||||||
|
eq(
|
||||||
|
clientSiteResourcesAssociationsCache.clientId,
|
||||||
|
client.clientId
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (jitMode) {
|
||||||
|
// Add site configuration to the array
|
||||||
|
siteConfigurations.push({
|
||||||
|
siteId: site.siteId,
|
||||||
|
// remoteSubnets: generateRemoteSubnets(
|
||||||
|
// allSiteResources.map(({ siteResources }) => siteResources)
|
||||||
|
// ),
|
||||||
|
aliases: generateAliasConfig(
|
||||||
|
allSiteResources.map(({ siteResources }) => siteResources)
|
||||||
|
)
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (!site.exitNodeId) {
|
if (!site.exitNodeId) {
|
||||||
logger.warn(
|
logger.warn(
|
||||||
`Site ${site.siteId} does not have exit node, skipping`
|
`Site ${site.siteId} does not have exit node, skipping`
|
||||||
@@ -110,26 +166,6 @@ export async function buildSiteConfigurationForOlmClient(
|
|||||||
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
const allSiteResources = await db // only get the site resources that this client has access to
|
|
||||||
.select()
|
|
||||||
.from(siteResources)
|
|
||||||
.innerJoin(
|
|
||||||
clientSiteResourcesAssociationsCache,
|
|
||||||
eq(
|
|
||||||
siteResources.siteResourceId,
|
|
||||||
clientSiteResourcesAssociationsCache.siteResourceId
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.where(
|
|
||||||
and(
|
|
||||||
eq(siteResources.siteId, site.siteId),
|
|
||||||
eq(
|
|
||||||
clientSiteResourcesAssociationsCache.clientId,
|
|
||||||
client.clientId
|
|
||||||
)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
// Add site configuration to the array
|
// Add site configuration to the array
|
||||||
siteConfigurations.push({
|
siteConfigurations.push({
|
||||||
siteId: site.siteId,
|
siteId: site.siteId,
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ import { getUserDeviceName } from "@server/db/names";
|
|||||||
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
||||||
import { OlmErrorCodes, sendOlmError } from "./error";
|
import { OlmErrorCodes, sendOlmError } from "./error";
|
||||||
import { handleFingerprintInsertion } from "./fingerprintingUtils";
|
import { handleFingerprintInsertion } from "./fingerprintingUtils";
|
||||||
|
import { Alias } from "@server/lib/ip";
|
||||||
|
import { build } from "@server/build";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
|
|
||||||
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||||
logger.info("Handling register olm message!");
|
logger.info("Handling register olm message!");
|
||||||
@@ -207,6 +210,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get all sites data
|
||||||
|
const sitesCountResult = await db
|
||||||
|
.select({ count: count() })
|
||||||
|
.from(sites)
|
||||||
|
.innerJoin(
|
||||||
|
clientSitesAssociationsCache,
|
||||||
|
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||||
|
)
|
||||||
|
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||||
|
|
||||||
|
// Extract the count value from the result array
|
||||||
|
const sitesCount =
|
||||||
|
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
|
||||||
|
|
||||||
|
// Prepare an array to store site configurations
|
||||||
|
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
|
||||||
|
|
||||||
|
let jitMode = true;
|
||||||
|
if (sitesCount > 250 && build == "saas") {
|
||||||
|
// THIS IS THE MAX ON THE BUSINESS TIER
|
||||||
|
// we have too many sites
|
||||||
|
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
|
||||||
|
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount);
|
||||||
|
jitMode = true;
|
||||||
|
}
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
|
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
|
||||||
);
|
);
|
||||||
@@ -233,28 +262,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
|||||||
await db
|
await db
|
||||||
.update(clientSitesAssociationsCache)
|
.update(clientSitesAssociationsCache)
|
||||||
.set({
|
.set({
|
||||||
isRelayed: relay == true
|
isRelayed: relay == true,
|
||||||
|
isJitMode: jitMode
|
||||||
})
|
})
|
||||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get all sites data
|
|
||||||
const sitesCountResult = await db
|
|
||||||
.select({ count: count() })
|
|
||||||
.from(sites)
|
|
||||||
.innerJoin(
|
|
||||||
clientSitesAssociationsCache,
|
|
||||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
|
||||||
)
|
|
||||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
|
||||||
|
|
||||||
// Extract the count value from the result array
|
|
||||||
const sitesCount =
|
|
||||||
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
|
|
||||||
|
|
||||||
// Prepare an array to store site configurations
|
|
||||||
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
|
|
||||||
|
|
||||||
// this prevents us from accepting a register from an olm that has not hole punched yet.
|
// this prevents us from accepting a register from an olm that has not hole punched yet.
|
||||||
// the olm will pump the register so we can keep checking
|
// the olm will pump the register so we can keep checking
|
||||||
// TODO: I still think there is a better way to do this rather than locking it out here but ???
|
// TODO: I still think there is a better way to do this rather than locking it out here but ???
|
||||||
@@ -269,15 +282,10 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
|||||||
const siteConfigurations = await buildSiteConfigurationForOlmClient(
|
const siteConfigurations = await buildSiteConfigurationForOlmClient(
|
||||||
client,
|
client,
|
||||||
publicKey,
|
publicKey,
|
||||||
relay
|
relay,
|
||||||
|
jitMode
|
||||||
);
|
);
|
||||||
|
|
||||||
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
|
|
||||||
// if (siteConfigurations.length === 0) {
|
|
||||||
// logger.warn("No valid site configurations found");
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
|
|
||||||
// Return connect message with all site configurations
|
// Return connect message with all site configurations
|
||||||
return {
|
return {
|
||||||
message: {
|
message: {
|
||||||
@@ -288,6 +296,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
|||||||
utilitySubnet: org.utilitySubnet
|
utilitySubnet: org.utilitySubnet
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
options: {
|
||||||
|
compress: canCompress(olm.version, "olm")
|
||||||
|
},
|
||||||
broadcast: false,
|
broadcast: false,
|
||||||
excludeSender: false
|
excludeSender: false
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!olm.clientId) {
|
if (!olm.clientId) {
|
||||||
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
|
logger.warn("Olm has no client!");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { siteId } = message.data;
|
const { siteId, chainId } = message.data;
|
||||||
|
|
||||||
// Get the site
|
// Get the site
|
||||||
const [site] = await db
|
const [site] = await db
|
||||||
@@ -90,7 +90,8 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
|||||||
data: {
|
data: {
|
||||||
siteId: siteId,
|
siteId: siteId,
|
||||||
relayEndpoint: exitNode.endpoint,
|
relayEndpoint: exitNode.endpoint,
|
||||||
relayPort: config.getRawConfig().gerbil.clients_start_port
|
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||||
|
chainId
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
broadcast: false,
|
broadcast: false,
|
||||||
|
|||||||
241
server/routers/olm/handleOlmServerInitAddPeerHandshake.ts
Normal file
241
server/routers/olm/handleOlmServerInitAddPeerHandshake.ts
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import {
|
||||||
|
clientSiteResourcesAssociationsCache,
|
||||||
|
clientSitesAssociationsCache,
|
||||||
|
db,
|
||||||
|
exitNodes,
|
||||||
|
Site,
|
||||||
|
siteResources
|
||||||
|
} from "@server/db";
|
||||||
|
import { MessageHandler } from "@server/routers/ws";
|
||||||
|
import { clients, Olm, sites } from "@server/db";
|
||||||
|
import { and, eq, or } from "drizzle-orm";
|
||||||
|
import logger from "@server/logger";
|
||||||
|
import { initPeerAddHandshake } from "./peers";
|
||||||
|
|
||||||
|
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
|
||||||
|
context
|
||||||
|
) => {
|
||||||
|
logger.info("Handling register olm message!");
|
||||||
|
const { message, client: c, sendToClient } = context;
|
||||||
|
const olm = c as Olm;
|
||||||
|
|
||||||
|
if (!olm) {
|
||||||
|
logger.warn("Olm not found");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!olm.clientId) {
|
||||||
|
logger.warn("Olm has no client!"); // TODO: Maybe we create the site here?
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const clientId = olm.clientId;
|
||||||
|
|
||||||
|
const [client] = await db
|
||||||
|
.select()
|
||||||
|
.from(clients)
|
||||||
|
.where(eq(clients.clientId, clientId))
|
||||||
|
.limit(1);
|
||||||
|
|
||||||
|
if (!client) {
|
||||||
|
logger.warn("Client not found");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { siteId, resourceId, chainId } = message.data;
|
||||||
|
|
||||||
|
let site: Site | null = null;
|
||||||
|
if (siteId) {
|
||||||
|
// get the site
|
||||||
|
const [siteRes] = await db
|
||||||
|
.select()
|
||||||
|
.from(sites)
|
||||||
|
.where(eq(sites.siteId, siteId))
|
||||||
|
.limit(1);
|
||||||
|
if (siteRes) {
|
||||||
|
site = siteRes;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resourceId && !site) {
|
||||||
|
const resources = await db
|
||||||
|
.select()
|
||||||
|
.from(siteResources)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
or(
|
||||||
|
eq(siteResources.niceId, resourceId),
|
||||||
|
eq(siteResources.alias, resourceId)
|
||||||
|
),
|
||||||
|
eq(siteResources.orgId, client.orgId)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!resources || resources.length === 0) {
|
||||||
|
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
|
||||||
|
// cancel the request from the olm side to not keep doing this
|
||||||
|
await sendToClient(
|
||||||
|
olm.olmId,
|
||||||
|
{
|
||||||
|
type: "olm/wg/peer/chain/cancel",
|
||||||
|
data: {
|
||||||
|
chainId
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ incrementConfigVersion: false }
|
||||||
|
).catch((error) => {
|
||||||
|
logger.warn(`Error sending message:`, error);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resources.length > 1) {
|
||||||
|
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
|
||||||
|
logger.error(
|
||||||
|
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const resource = resources[0];
|
||||||
|
|
||||||
|
const currentResourceAssociationCaches = await db
|
||||||
|
.select()
|
||||||
|
.from(clientSiteResourcesAssociationsCache)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(
|
||||||
|
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||||
|
resource.siteResourceId
|
||||||
|
),
|
||||||
|
eq(
|
||||||
|
clientSiteResourcesAssociationsCache.clientId,
|
||||||
|
client.clientId
|
||||||
|
)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (currentResourceAssociationCaches.length === 0) {
|
||||||
|
logger.error(
|
||||||
|
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
|
||||||
|
);
|
||||||
|
// cancel the request from the olm side to not keep doing this
|
||||||
|
await sendToClient(
|
||||||
|
olm.olmId,
|
||||||
|
{
|
||||||
|
type: "olm/wg/peer/chain/cancel",
|
||||||
|
data: {
|
||||||
|
chainId
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ incrementConfigVersion: false }
|
||||||
|
).catch((error) => {
|
||||||
|
logger.warn(`Error sending message:`, error);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const siteIdFromResource = resource.siteId;
|
||||||
|
|
||||||
|
// get the site
|
||||||
|
const [siteRes] = await db
|
||||||
|
.select()
|
||||||
|
.from(sites)
|
||||||
|
.where(eq(sites.siteId, siteIdFromResource));
|
||||||
|
if (!siteRes) {
|
||||||
|
logger.error(
|
||||||
|
`handleOlmServerPeerAddMessage: Site with ID ${site} not found`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
site = siteRes;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!site) {
|
||||||
|
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the client can access this site using the cache
|
||||||
|
const currentSiteAssociationCaches = await db
|
||||||
|
.select()
|
||||||
|
.from(clientSitesAssociationsCache)
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||||
|
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (currentSiteAssociationCaches.length === 0) {
|
||||||
|
logger.error(
|
||||||
|
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
|
||||||
|
);
|
||||||
|
// cancel the request from the olm side to not keep doing this
|
||||||
|
await sendToClient(
|
||||||
|
olm.olmId,
|
||||||
|
{
|
||||||
|
type: "olm/wg/peer/chain/cancel",
|
||||||
|
data: {
|
||||||
|
chainId
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ incrementConfigVersion: false }
|
||||||
|
).catch((error) => {
|
||||||
|
logger.warn(`Error sending message:`, error);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!site.exitNodeId) {
|
||||||
|
logger.error(
|
||||||
|
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
|
||||||
|
);
|
||||||
|
// cancel the request from the olm side to not keep doing this
|
||||||
|
await sendToClient(
|
||||||
|
olm.olmId,
|
||||||
|
{
|
||||||
|
type: "olm/wg/peer/chain/cancel",
|
||||||
|
data: {
|
||||||
|
chainId
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ incrementConfigVersion: false }
|
||||||
|
).catch((error) => {
|
||||||
|
logger.warn(`Error sending message:`, error);
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// get the exit node from the side
|
||||||
|
const [exitNode] = await db
|
||||||
|
.select()
|
||||||
|
.from(exitNodes)
|
||||||
|
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
|
||||||
|
|
||||||
|
if (!exitNode) {
|
||||||
|
logger.error(
|
||||||
|
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
|
||||||
|
// if it has already been added this will be a no-op
|
||||||
|
await initPeerAddHandshake(
|
||||||
|
// this will kick off the add peer process for the client
|
||||||
|
client.clientId,
|
||||||
|
{
|
||||||
|
siteId: site.siteId,
|
||||||
|
exitNode: {
|
||||||
|
publicKey: exitNode.publicKey,
|
||||||
|
endpoint: exitNode.endpoint
|
||||||
|
}
|
||||||
|
},
|
||||||
|
olm.olmId,
|
||||||
|
chainId
|
||||||
|
);
|
||||||
|
|
||||||
|
return;
|
||||||
|
};
|
||||||
@@ -54,7 +54,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { siteId } = message.data;
|
const { siteId, chainId } = message.data;
|
||||||
|
|
||||||
// get the site
|
// get the site
|
||||||
const [site] = await db
|
const [site] = await db
|
||||||
@@ -179,7 +179,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
|
|||||||
),
|
),
|
||||||
aliases: generateAliasConfig(
|
aliases: generateAliasConfig(
|
||||||
allSiteResources.map(({ siteResources }) => siteResources)
|
allSiteResources.map(({ siteResources }) => siteResources)
|
||||||
)
|
),
|
||||||
|
chainId: chainId,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
broadcast: false,
|
broadcast: false,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!olm.clientId) {
|
if (!olm.clientId) {
|
||||||
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
|
logger.warn("Olm has no client!");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -40,7 +40,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { siteId } = message.data;
|
const { siteId, chainId } = message.data;
|
||||||
|
|
||||||
// Get the site
|
// Get the site
|
||||||
const [site] = await db
|
const [site] = await db
|
||||||
@@ -87,7 +87,8 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
|||||||
type: "olm/wg/peer/unrelay",
|
type: "olm/wg/peer/unrelay",
|
||||||
data: {
|
data: {
|
||||||
siteId: siteId,
|
siteId: siteId,
|
||||||
endpoint: site.endpoint
|
endpoint: site.endpoint,
|
||||||
|
chainId
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
broadcast: false,
|
broadcast: false,
|
||||||
|
|||||||
@@ -11,3 +11,4 @@ export * from "./handleOlmServerPeerAddMessage";
|
|||||||
export * from "./handleOlmUnRelayMessage";
|
export * from "./handleOlmUnRelayMessage";
|
||||||
export * from "./recoverOlmWithFingerprint";
|
export * from "./recoverOlmWithFingerprint";
|
||||||
export * from "./handleOlmDisconnectingMessage";
|
export * from "./handleOlmDisconnectingMessage";
|
||||||
|
export * from "./handleOlmServerInitAddPeerHandshake";
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { sendToClient } from "#dynamic/routers/ws";
|
import { sendToClient } from "#dynamic/routers/ws";
|
||||||
import { db, olms } from "@server/db";
|
import { clientSitesAssociationsCache, db, olms } from "@server/db";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
import config from "@server/lib/config";
|
import config from "@server/lib/config";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { eq } from "drizzle-orm";
|
import { and, eq } from "drizzle-orm";
|
||||||
import { Alias } from "yaml";
|
import { Alias } from "yaml";
|
||||||
|
|
||||||
export async function addPeer(
|
export async function addPeer(
|
||||||
@@ -18,7 +19,8 @@ export async function addPeer(
|
|||||||
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
|
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
|
||||||
aliases: Alias[];
|
aliases: Alias[];
|
||||||
},
|
},
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -30,6 +32,7 @@ export async function addPeer(
|
|||||||
return; // ignore this because an olm might not be associated with the client anymore
|
return; // ignore this because an olm might not be associated with the client anymore
|
||||||
}
|
}
|
||||||
olmId = olm.olmId;
|
olmId = olm.olmId;
|
||||||
|
version = olm.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
await sendToClient(
|
await sendToClient(
|
||||||
@@ -48,7 +51,7 @@ export async function addPeer(
|
|||||||
aliases: peer.aliases
|
aliases: peer.aliases
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ incrementConfigVersion: true }
|
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||||
).catch((error) => {
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
@@ -60,7 +63,8 @@ export async function deletePeer(
|
|||||||
clientId: number,
|
clientId: number,
|
||||||
siteId: number,
|
siteId: number,
|
||||||
publicKey: string,
|
publicKey: string,
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -72,6 +76,7 @@ export async function deletePeer(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
olmId = olm.olmId;
|
olmId = olm.olmId;
|
||||||
|
version = olm.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
await sendToClient(
|
await sendToClient(
|
||||||
@@ -83,7 +88,7 @@ export async function deletePeer(
|
|||||||
siteId: siteId
|
siteId: siteId
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ incrementConfigVersion: true }
|
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||||
).catch((error) => {
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
@@ -103,7 +108,8 @@ export async function updatePeer(
|
|||||||
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
|
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
|
||||||
aliases?: Alias[] | null;
|
aliases?: Alias[] | null;
|
||||||
},
|
},
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
version?: string | null
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -115,6 +121,7 @@ export async function updatePeer(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
olmId = olm.olmId;
|
olmId = olm.olmId;
|
||||||
|
version = olm.version;
|
||||||
}
|
}
|
||||||
|
|
||||||
await sendToClient(
|
await sendToClient(
|
||||||
@@ -132,7 +139,7 @@ export async function updatePeer(
|
|||||||
aliases: peer.aliases
|
aliases: peer.aliases
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ incrementConfigVersion: true }
|
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||||
).catch((error) => {
|
).catch((error) => {
|
||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
@@ -149,7 +156,8 @@ export async function initPeerAddHandshake(
|
|||||||
endpoint: string;
|
endpoint: string;
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
olmId?: string
|
olmId?: string,
|
||||||
|
chainId?: string
|
||||||
) {
|
) {
|
||||||
if (!olmId) {
|
if (!olmId) {
|
||||||
const [olm] = await db
|
const [olm] = await db
|
||||||
@@ -173,7 +181,8 @@ export async function initPeerAddHandshake(
|
|||||||
publicKey: peer.exitNode.publicKey,
|
publicKey: peer.exitNode.publicKey,
|
||||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||||
endpoint: peer.exitNode.endpoint
|
endpoint: peer.exitNode.endpoint
|
||||||
}
|
},
|
||||||
|
chainId
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{ incrementConfigVersion: true }
|
{ incrementConfigVersion: true }
|
||||||
@@ -181,6 +190,17 @@ export async function initPeerAddHandshake(
|
|||||||
logger.warn(`Error sending message:`, error);
|
logger.warn(`Error sending message:`, error);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection
|
||||||
|
await db
|
||||||
|
.update(clientSitesAssociationsCache)
|
||||||
|
.set({ isJitMode: false })
|
||||||
|
.where(
|
||||||
|
and(
|
||||||
|
eq(clientSitesAssociationsCache.clientId, clientId),
|
||||||
|
eq(clientSitesAssociationsCache.siteId, peer.siteId)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
|
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
|
import {
|
||||||
|
Client,
|
||||||
|
db,
|
||||||
|
exitNodes,
|
||||||
|
Olm,
|
||||||
|
sites,
|
||||||
|
clientSitesAssociationsCache
|
||||||
|
} from "@server/db";
|
||||||
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
||||||
import { sendToClient } from "#dynamic/routers/ws";
|
import { sendToClient } from "#dynamic/routers/ws";
|
||||||
import logger from "@server/logger";
|
import logger from "@server/logger";
|
||||||
import { eq, inArray } from "drizzle-orm";
|
import { eq, inArray } from "drizzle-orm";
|
||||||
import config from "@server/lib/config";
|
import config from "@server/lib/config";
|
||||||
|
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||||
|
|
||||||
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
|
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
|
||||||
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
|
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
|
||||||
@@ -17,10 +25,7 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
|
|||||||
const clientSites = await db
|
const clientSites = await db
|
||||||
.select()
|
.select()
|
||||||
.from(clientSitesAssociationsCache)
|
.from(clientSitesAssociationsCache)
|
||||||
.innerJoin(
|
.innerJoin(sites, eq(sites.siteId, clientSitesAssociationsCache.siteId))
|
||||||
sites,
|
|
||||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
|
||||||
)
|
|
||||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||||
|
|
||||||
// Extract unique exit node IDs
|
// Extract unique exit node IDs
|
||||||
@@ -68,13 +73,20 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
|
|||||||
|
|
||||||
logger.debug("sendOlmSyncMessage: sending sync message");
|
logger.debug("sendOlmSyncMessage: sending sync message");
|
||||||
|
|
||||||
await sendToClient(olm.olmId, {
|
await sendToClient(
|
||||||
|
olm.olmId,
|
||||||
|
{
|
||||||
type: "olm/sync",
|
type: "olm/sync",
|
||||||
data: {
|
data: {
|
||||||
sites: siteConfigurations,
|
sites: siteConfigurations,
|
||||||
exitNodes: exitNodesData
|
exitNodes: exitNodesData
|
||||||
}
|
}
|
||||||
}).catch((error) => {
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
compress: canCompress(olm.version, "olm")
|
||||||
|
}
|
||||||
|
).catch((error) => {
|
||||||
logger.warn(`Error sending olm sync message:`, error);
|
logger.warn(`Error sending olm sync message:`, error);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -620,7 +620,7 @@ export async function handleMessagingForUpdatedSiteResource(
|
|||||||
await updateTargets(newt.newtId, {
|
await updateTargets(newt.newtId, {
|
||||||
oldTargets: oldTargets,
|
oldTargets: oldTargets,
|
||||||
newTargets: newTargets
|
newTargets: newTargets
|
||||||
});
|
}, newt.version);
|
||||||
}
|
}
|
||||||
|
|
||||||
const olmJobs: Promise<void>[] = [];
|
const olmJobs: Promise<void>[] = [];
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ export async function createTarget(
|
|||||||
newTarget,
|
newTarget,
|
||||||
healthCheck,
|
healthCheck,
|
||||||
resource.protocol,
|
resource.protocol,
|
||||||
resource.proxyPort
|
newt.version
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -262,7 +262,7 @@ export async function updateTarget(
|
|||||||
[updatedTarget],
|
[updatedTarget],
|
||||||
[updatedHc],
|
[updatedHc],
|
||||||
resource.protocol,
|
resource.protocol,
|
||||||
resource.proxyPort
|
newt.version
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ import {
|
|||||||
handleDockerContainersMessage,
|
handleDockerContainersMessage,
|
||||||
handleNewtPingRequestMessage,
|
handleNewtPingRequestMessage,
|
||||||
handleApplyBlueprintMessage,
|
handleApplyBlueprintMessage,
|
||||||
handleNewtPingMessage
|
handleNewtPingMessage,
|
||||||
|
startNewtOfflineChecker
|
||||||
} from "../newt";
|
} from "../newt";
|
||||||
import {
|
import {
|
||||||
handleOlmRegisterMessage,
|
handleOlmRegisterMessage,
|
||||||
@@ -15,7 +16,8 @@ import {
|
|||||||
startOlmOfflineChecker,
|
startOlmOfflineChecker,
|
||||||
handleOlmServerPeerAddMessage,
|
handleOlmServerPeerAddMessage,
|
||||||
handleOlmUnRelayMessage,
|
handleOlmUnRelayMessage,
|
||||||
handleOlmDisconnecingMessage
|
handleOlmDisconnecingMessage,
|
||||||
|
handleOlmServerInitAddPeerHandshake
|
||||||
} from "../olm";
|
} from "../olm";
|
||||||
import { handleHealthcheckStatusMessage } from "../target";
|
import { handleHealthcheckStatusMessage } from "../target";
|
||||||
import { handleRoundTripMessage } from "./handleRoundTripMessage";
|
import { handleRoundTripMessage } from "./handleRoundTripMessage";
|
||||||
@@ -23,6 +25,7 @@ import { MessageHandler } from "./types";
|
|||||||
|
|
||||||
export const messageHandlers: Record<string, MessageHandler> = {
|
export const messageHandlers: Record<string, MessageHandler> = {
|
||||||
"olm/wg/server/peer/add": handleOlmServerPeerAddMessage,
|
"olm/wg/server/peer/add": handleOlmServerPeerAddMessage,
|
||||||
|
"olm/wg/server/peer/init": handleOlmServerInitAddPeerHandshake,
|
||||||
"olm/wg/register": handleOlmRegisterMessage,
|
"olm/wg/register": handleOlmRegisterMessage,
|
||||||
"olm/wg/relay": handleOlmRelayMessage,
|
"olm/wg/relay": handleOlmRelayMessage,
|
||||||
"olm/wg/unrelay": handleOlmUnRelayMessage,
|
"olm/wg/unrelay": handleOlmUnRelayMessage,
|
||||||
@@ -41,3 +44,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
startOlmOfflineChecker(); // this is to handle the offline check for olms
|
startOlmOfflineChecker(); // this is to handle the offline check for olms
|
||||||
|
startNewtOfflineChecker(); // this is to handle the offline check for newts
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ export interface AuthenticatedWebSocket extends WebSocket {
|
|||||||
clientType?: ClientType;
|
clientType?: ClientType;
|
||||||
connectionId?: string;
|
connectionId?: string;
|
||||||
isFullyConnected?: boolean;
|
isFullyConnected?: boolean;
|
||||||
pendingMessages?: Buffer[];
|
pendingMessages?: { data: Buffer; isBinary: boolean }[];
|
||||||
configVersion?: number;
|
configVersion?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -73,6 +73,7 @@ export type MessageHandler = (
|
|||||||
// Options for sending messages with config version tracking
|
// Options for sending messages with config version tracking
|
||||||
export interface SendMessageOptions {
|
export interface SendMessageOptions {
|
||||||
incrementConfigVersion?: boolean;
|
incrementConfigVersion?: boolean;
|
||||||
|
compress?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Redis message type for cross-node communication
|
// Redis message type for cross-node communication
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
import { Router, Request, Response } from "express";
|
import { Router, Request, Response } from "express";
|
||||||
|
import zlib from "zlib";
|
||||||
import { Server as HttpServer } from "http";
|
import { Server as HttpServer } from "http";
|
||||||
import { WebSocket, WebSocketServer } from "ws";
|
import { WebSocket, WebSocketServer } from "ws";
|
||||||
import { Socket } from "net";
|
import { Socket } from "net";
|
||||||
import { Newt, newts, NewtSession, olms, Olm, OlmSession } from "@server/db";
|
import { Newt, newts, NewtSession, olms, Olm, OlmSession, sites } from "@server/db";
|
||||||
import { eq } from "drizzle-orm";
|
import { eq } from "drizzle-orm";
|
||||||
import { db } from "@server/db";
|
import { db } from "@server/db";
|
||||||
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
|
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
|
||||||
@@ -116,11 +117,20 @@ const sendToClientLocal = async (
|
|||||||
};
|
};
|
||||||
|
|
||||||
const messageString = JSON.stringify(messageWithVersion);
|
const messageString = JSON.stringify(messageWithVersion);
|
||||||
|
if (options.compress) {
|
||||||
|
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
|
||||||
|
clients.forEach((client) => {
|
||||||
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
|
client.send(compressed);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
clients.forEach((client) => {
|
clients.forEach((client) => {
|
||||||
if (client.readyState === WebSocket.OPEN) {
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
client.send(messageString);
|
client.send(messageString);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -147,12 +157,23 @@ const broadcastToAllExceptLocal = async (
|
|||||||
...message,
|
...message,
|
||||||
configVersion
|
configVersion
|
||||||
};
|
};
|
||||||
|
if (options.compress) {
|
||||||
|
const compressed = zlib.gzipSync(
|
||||||
|
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
|
||||||
|
);
|
||||||
|
clients.forEach((client) => {
|
||||||
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
|
client.send(compressed);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
clients.forEach((client) => {
|
clients.forEach((client) => {
|
||||||
if (client.readyState === WebSocket.OPEN) {
|
if (client.readyState === WebSocket.OPEN) {
|
||||||
client.send(JSON.stringify(messageWithVersion));
|
client.send(JSON.stringify(messageWithVersion));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -286,9 +307,12 @@ const setupConnection = async (
|
|||||||
clientType === "newt" ? (client as Newt).newtId : (client as Olm).olmId;
|
clientType === "newt" ? (client as Newt).newtId : (client as Olm).olmId;
|
||||||
await addClient(clientType, clientId, ws);
|
await addClient(clientType, clientId, ws);
|
||||||
|
|
||||||
ws.on("message", async (data) => {
|
ws.on("message", async (data, isBinary) => {
|
||||||
try {
|
try {
|
||||||
const message: WSMessage = JSON.parse(data.toString());
|
const messageBuffer = isBinary
|
||||||
|
? zlib.gunzipSync(data as Buffer)
|
||||||
|
: (data as Buffer);
|
||||||
|
const message: WSMessage = JSON.parse(messageBuffer.toString());
|
||||||
|
|
||||||
if (!message.type || typeof message.type !== "string") {
|
if (!message.type || typeof message.type !== "string") {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
@@ -356,6 +380,31 @@ const setupConnection = async (
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Handle WebSocket protocol-level pings from older newt clients that do
|
||||||
|
// not send application-level "newt/ping" messages. Update the site's
|
||||||
|
// online state and lastPing timestamp so the offline checker treats them
|
||||||
|
// the same as modern newt clients.
|
||||||
|
if (clientType === "newt") {
|
||||||
|
const newtClient = client as Newt;
|
||||||
|
ws.on("ping", async () => {
|
||||||
|
if (!newtClient.siteId) return;
|
||||||
|
try {
|
||||||
|
await db
|
||||||
|
.update(sites)
|
||||||
|
.set({
|
||||||
|
online: true,
|
||||||
|
lastPing: Math.floor(Date.now() / 1000)
|
||||||
|
})
|
||||||
|
.where(eq(sites.siteId, newtClient.siteId));
|
||||||
|
} catch (error) {
|
||||||
|
logger.error(
|
||||||
|
"Error updating newt site online state on WS ping",
|
||||||
|
{ error }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
ws.on("error", (error: Error) => {
|
ws.on("error", (error: Error) => {
|
||||||
logger.error(
|
logger.error(
|
||||||
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
|
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
|
||||||
|
|||||||
Reference in New Issue
Block a user