Merge branch 'dev' into msg-opt

This commit is contained in:
Owen
2026-03-23 16:00:50 -07:00
203 changed files with 4955 additions and 4932 deletions

View File

@@ -43,7 +43,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/access-token",
description: "Generate a new access token for a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.AccessToken],
tags: [OpenAPITags.PublicResource, OpenAPITags.AccessToken],
request: {
params: generateAccssTokenParamsSchema,
body: {

View File

@@ -122,7 +122,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/access-tokens",
description: "List all access tokens in an organization.",
tags: [OpenAPITags.Org, OpenAPITags.AccessToken],
tags: [OpenAPITags.AccessToken],
request: {
params: z.object({
orgId: z.string()
@@ -135,8 +135,8 @@ registry.registerPath({
registry.registerPath({
method: "get",
path: "/resource/{resourceId}/access-tokens",
description: "List all access tokens in an organization.",
tags: [OpenAPITags.Resource, OpenAPITags.AccessToken],
description: "List all access tokens for a resource.",
tags: [OpenAPITags.PublicResource, OpenAPITags.AccessToken],
request: {
params: z.object({
resourceId: z.number()

View File

@@ -37,7 +37,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/api-key",
description: "Create a new API key scoped to the organization.",
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
tags: [OpenAPITags.ApiKey],
request: {
params: paramsSchema,
body: {

View File

@@ -18,7 +18,7 @@ registry.registerPath({
method: "delete",
path: "/org/{orgId}/api-key/{apiKeyId}",
description: "Delete an API key.",
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
tags: [OpenAPITags.ApiKey],
request: {
params: paramsSchema
},

View File

@@ -48,7 +48,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/api-key/{apiKeyId}/actions",
description: "List all actions set for an API key.",
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
tags: [OpenAPITags.ApiKey],
request: {
params: paramsSchema,
query: querySchema

View File

@@ -52,7 +52,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/api-keys",
description: "List all API keys for an organization",
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
tags: [OpenAPITags.ApiKey],
request: {
params: paramsSchema,
query: querySchema

View File

@@ -25,7 +25,7 @@ registry.registerPath({
path: "/org/{orgId}/api-key/{apiKeyId}/actions",
description:
"Set actions for an API key. This will replace any existing actions.",
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
tags: [OpenAPITags.ApiKey],
request: {
params: paramsSchema,
body: {

View File

@@ -20,7 +20,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/request",
description: "Query the request audit log for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery.omit({
limit: true,

View File

@@ -151,7 +151,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/analytics",
description: "Query the request audit analytics for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery,
params: queryRequestAuditLogsParams

View File

@@ -182,7 +182,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/logs/request",
description: "Query the request audit log for an organization",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Logs],
request: {
query: queryAccessAuditLogsQuery,
params: queryRequestAuditLogsParams

View File

@@ -1,5 +1,5 @@
import { NextFunction, Request, Response } from "express";
import { db, users } from "@server/db";
import { bannedEmails, bannedIps, db, users } from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { email, z } from "zod";
import { fromError } from "zod-validation-error";
@@ -22,7 +22,6 @@ import { checkValidInvite } from "@server/auth/checkValidInvite";
import { passwordSchema } from "@server/auth/passwordSchema";
import { UserType } from "@server/types/UserTypes";
import { build } from "@server/build";
import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend";
export const signupBodySchema = z.object({
email: z.email().toLowerCase(),
@@ -66,6 +65,30 @@ export async function signup(
skipVerificationEmail
} = parsedBody.data;
const [bannedEmail] = await db
.select()
.from(bannedEmails)
.where(eq(bannedEmails.email, email))
.limit(1);
if (bannedEmail) {
return next(
createHttpError(HttpCode.FORBIDDEN, "Signup blocked. Do not attempt to continue to use this service.")
);
}
if (req.ip) {
const [bannedIp] = await db
.select()
.from(bannedIps)
.where(eq(bannedIps.ip, req.ip))
.limit(1);
if (bannedIp) {
return next(
createHttpError(HttpCode.FORBIDDEN, "Signup blocked. Do not attempt to continue to use this service.")
);
}
}
const passwordHash = await hashPassword(password);
const userId = generateId(15);
@@ -189,6 +212,7 @@ export async function signup(
dateCreated: moment().toISOString(),
termsAcceptedTimestamp: termsAcceptedTimestamp || null,
termsVersion: "1",
marketingEmailConsent: marketingEmailConsent ?? false,
lastPasswordChange: new Date().getTime()
});
@@ -212,7 +236,7 @@ export async function signup(
logger.debug(
`User ${email} opted in to marketing emails during signup.`
);
moveEmailToAudience(email, AudienceIds.SignUps);
// TODO: update user in Sendy
}
if (config.getRawConfig().flags?.require_email_verification) {

View File

@@ -5,6 +5,8 @@ import cache from "#dynamic/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
import { stripPortFromHost } from "@server/lib/ip";
import { sanitizeString } from "@server/lib/sanitize";
/**
Reasons:
@@ -253,24 +255,23 @@ export async function logRequestAudit(
// Add to buffer instead of writing directly to DB
auditLogBuffer.push({
timestamp,
orgId: data.orgId,
actorType,
actor,
actorId,
metadata,
orgId: sanitizeString(data.orgId),
actorType: sanitizeString(actorType),
actor: sanitizeString(actor),
actorId: sanitizeString(actorId),
metadata: sanitizeString(metadata),
action: data.action,
resourceId: data.resourceId,
reason: data.reason,
location: data.location,
originalRequestURL: body.originalRequestURL,
scheme: body.scheme,
host: body.host,
path: body.path,
method: body.method,
ip: clientIp,
location: sanitizeString(data.location),
originalRequestURL: sanitizeString(body.originalRequestURL) ?? "",
scheme: sanitizeString(body.scheme) ?? "",
host: sanitizeString(body.host) ?? "",
path: sanitizeString(body.path) ?? "",
method: sanitizeString(body.method) ?? "",
ip: sanitizeString(clientIp),
tls: body.tls
});
// Flush immediately if buffer is full, otherwise schedule a flush
if (auditLogBuffer.length >= BATCH_SIZE) {
// Fire and forget - don't block the caller
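The sanitizeString helper imported from @server/lib/sanitize is applied to the buffered fields above, but its implementation is not part of this diff. A minimal sketch of what such a helper might look like, assuming its purpose is to strip control characters so untrusted request fields cannot inject newlines or escape sequences into buffered audit rows (the real implementation may differ):

// Hypothetical sketch only; not the actual @server/lib/sanitize implementation.
export function sanitizeString(value: string | null | undefined): string | undefined {
    if (value == null) {
        return undefined; // callers fall back with `?? ""` where a string is required
    }
    // Strip ASCII control characters (including newlines) and trim surrounding whitespace.
    return value.replace(/[\u0000-\u001f\u007f]/g, "").trim();
}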

View File

@@ -20,7 +20,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/blueprint",
description: "Apply a base64 encoded JSON blueprint to an organization",
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
tags: [OpenAPITags.Blueprint],
request: {
params: applyBlueprintParamsSchema,
body: {

View File

@@ -43,7 +43,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/blueprint",
description: "Create and apply a YAML blueprint to an organization",
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
tags: [OpenAPITags.Blueprint],
request: {
params: applyBlueprintParamsSchema,
body: {

View File

@@ -53,7 +53,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/blueprint/{blueprintId}",
description: "Get a blueprint by its blueprint ID.",
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
tags: [OpenAPITags.Blueprint],
request: {
params: getBlueprintSchema
},

View File

@@ -67,7 +67,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/blueprints",
description: "List all blueprints for an organization.",
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
tags: [OpenAPITags.Blueprint],
request: {
params: z.object({
orgId: z.string()

View File

@@ -48,7 +48,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/client",
description: "Create a new client for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.Client],
request: {
params: createClientParamsSchema,
body: {

View File

@@ -49,7 +49,7 @@ registry.registerPath({
path: "/org/{orgId}/user/{userId}/client",
description:
"Create a new client for a user and associate it with an existing olm.",
tags: [OpenAPITags.Client, OpenAPITags.Org, OpenAPITags.User],
tags: [OpenAPITags.Client],
request: {
params: paramsSchema,
body: {

View File

@@ -243,7 +243,7 @@ registry.registerPath({
path: "/org/{orgId}/client/{niceId}",
description:
"Get a client by orgId and niceId. NiceId is a readable ID for the site and unique on a per org basis.",
tags: [OpenAPITags.Org, OpenAPITags.Site],
tags: [OpenAPITags.Site],
request: {
params: z.object({
orgId: z.string(),

View File

@@ -70,7 +70,7 @@ async function getLatestOlmVersion(): Promise<string | null> {
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
olmVersionCache.set("latestOlmVersion", latestVersion);
olmVersionCache.set("latestOlmVersion", latestVersion, 3600);
return latestVersion;
} catch (error: any) {
@@ -237,7 +237,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/clients",
description: "List all clients for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.Client],
request: {
query: listClientsSchema,
params: listClientsParamsSchema

View File

@@ -71,7 +71,7 @@ async function getLatestOlmVersion(): Promise<string | null> {
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
olmVersionCache.set("latestOlmVersion", latestVersion);
olmVersionCache.set("latestOlmVersion", latestVersion, 3600);
return latestVersion;
} catch (error: any) {
@@ -256,7 +256,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/user-devices",
description: "List all user devices for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.Client],
request: {
query: listUserDevicesSchema,
params: listUserDevicesParamsSchema

View File

@@ -23,7 +23,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/pick-client-defaults",
description: "Return pre-requisite data for creating a client.",
tags: [OpenAPITags.Client, OpenAPITags.Site],
tags: [OpenAPITags.Client],
request: {
params: pickClientDefaultsSchema
},

View File

@@ -1,31 +1,16 @@
import { sendToClient } from "#dynamic/routers/ws";
import { S } from "@faker-js/faker/dist/airline-Dz1uGqgJ";
import { db, newts, olms, Transaction } from "@server/db";
import { db, newts, olms } from "@server/db";
import {
Alias,
convertSubnetProxyTargetsV2ToV1,
SubnetProxyTarget,
SubnetProxyTargetV2
} from "@server/lib/ip";
import { canCompress } from "@server/lib/clientVersionChecks";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
import semver from "semver";
const BATCH_SIZE = 50;
const BATCH_DELAY_MS = 50;
function sleep(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
function chunkArray<T>(array: T[], size: number): T[][] {
const chunks: T[][] = [];
for (let i = 0; i < array.length; i += size) {
chunks.push(array.slice(i, i + size));
}
return chunks;
}
const NEWT_V2_TARGETS_VERSION = ">=1.11.0";
export async function convertTargetsIfNessicary(
@@ -59,53 +44,36 @@ export async function convertTargetsIfNessicary(
export async function addTargets(
newtId: string,
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
version?: string | null
) {
targets = await convertTargetsIfNessicary(newtId, targets);
const batches = chunkArray<SubnetProxyTarget | SubnetProxyTargetV2>(
targets,
BATCH_SIZE
await sendToClient(
newtId,
{
type: `newt/wg/targets/add`,
data: targets
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
for (let i = 0; i < batches.length; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(
newtId,
{
type: `newt/wg/targets/add`,
data: batches[i]
},
{ incrementConfigVersion: true }
);
}
}
export async function removeTargets(
newtId: string,
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
version?: string | null
) {
targets = await convertTargetsIfNessicary(newtId, targets);
const batches = chunkArray<SubnetProxyTarget | SubnetProxyTargetV2>(
targets,
BATCH_SIZE
await sendToClient(
newtId,
{
type: `newt/wg/targets/remove`,
data: targets
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
);
for (let i = 0; i < batches.length; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(
newtId,
{
type: `newt/wg/targets/remove`,
data: batches[i]
},
{ incrementConfigVersion: true }
);
}
}
export async function updateTargets(
@@ -113,7 +81,8 @@ export async function updateTargets(
targets: {
oldTargets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
newTargets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
}
},
version?: string | null
) {
// get the newt
const [newt] = await db
@@ -143,35 +112,19 @@ export async function updateTargets(
};
}
const oldBatches = chunkArray<SubnetProxyTarget | SubnetProxyTargetV2>(
targets.oldTargets,
BATCH_SIZE
);
const newBatches = chunkArray<SubnetProxyTarget | SubnetProxyTargetV2>(
targets.newTargets,
BATCH_SIZE
);
const maxBatches = Math.max(oldBatches.length, newBatches.length);
for (let i = 0; i < maxBatches; i++) {
if (i > 0) {
await sleep(BATCH_DELAY_MS);
}
await sendToClient(
newtId,
{
type: `newt/wg/targets/update`,
data: {
oldTargets: oldBatches[i] || [],
newTargets: newBatches[i] || []
}
},
{ incrementConfigVersion: true }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
await sendToClient(
newtId,
{
type: `newt/wg/targets/update`,
data: {
oldTargets: targets.oldTargets,
newTargets: targets.newTargets
}
},
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
}
export async function addPeerData(
@@ -179,7 +132,8 @@ export async function addPeerData(
siteId: number,
remoteSubnets: string[],
aliases: Alias[],
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -191,6 +145,7 @@ export async function addPeerData(
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -203,7 +158,7 @@ export async function addPeerData(
aliases: aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -214,7 +169,8 @@ export async function removePeerData(
siteId: number,
remoteSubnets: string[],
aliases: Alias[],
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -226,6 +182,7 @@ export async function removePeerData(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -238,7 +195,7 @@ export async function removePeerData(
aliases: aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -259,7 +216,8 @@ export async function updatePeerData(
newAliases: Alias[];
}
| undefined,
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -271,6 +229,7 @@ export async function updatePeerData(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -283,7 +242,7 @@ export async function updatePeerData(
...aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
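These helpers now accept an optional client version and pass compress: canCompress(version, ...) to sendToClient. canCompress comes from @server/lib/clientVersionChecks and is not shown in this diff; a minimal sketch of the idea, assuming compression is gated on a per-client-type minimum semver (the version thresholds below are placeholders, not the real values):

// Hypothetical sketch only; the real canCompress in @server/lib/clientVersionChecks may differ.
import semver from "semver";

// Placeholder minimum versions for illustration.
const MIN_COMPRESSION_VERSION: Record<"newt" | "olm", string> = {
    newt: "1.12.0",
    olm: "1.3.0"
};

export function canCompress(
    version: string | null | undefined,
    clientType: "newt" | "olm"
): boolean {
    if (!version) {
        return false; // unknown version: skip compression to stay compatible
    }
    const parsed = semver.coerce(version);
    return parsed ? semver.gte(parsed, MIN_COMPRESSION_VERSION[clientType]) : false;
}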

View File

@@ -40,7 +40,8 @@ async function queryDomains(orgId: string, limit: number, offset: number) {
tries: domains.tries,
configManaged: domains.configManaged,
certResolver: domains.certResolver,
preferWildcardCert: domains.preferWildcardCert
preferWildcardCert: domains.preferWildcardCert,
errorMessage: domains.errorMessage
})
.from(orgDomains)
.where(eq(orgDomains.orgId, orgId))
@@ -59,7 +60,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/domains",
description: "List all domains for an organization.",
tags: [OpenAPITags.Org],
tags: [OpenAPITags.Domain],
request: {
params: z.object({
orgId: z.string()

View File

@@ -125,7 +125,7 @@ export async function generateRelayMappings(exitNode: ExitNode) {
// Add site as a destination for this client
const destination: PeerDestination = {
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
};
// Check if this destination is already in the array to avoid duplicates
@@ -165,7 +165,7 @@ export async function generateRelayMappings(exitNode: ExitNode) {
const destination: PeerDestination = {
destinationIP: peer.subnet.split("/")[0],
destinationPort: peer.listenPort
destinationPort: peer.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
};
// Check for duplicates

View File

@@ -1,5 +1,5 @@
import { Request, Response, NextFunction } from "express";
import { eq, and, lt, inArray, sql } from "drizzle-orm";
import { eq, sql } from "drizzle-orm";
import { sites } from "@server/db";
import { db } from "@server/db";
import logger from "@server/logger";
@@ -11,19 +11,31 @@ import { FeatureId } from "@server/lib/billing/features";
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
import { build } from "@server/build";
// Track sites that are already offline to avoid unnecessary queries
const offlineSites = new Set<string>();
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
interface PeerBandwidth {
publicKey: string;
bytesIn: number;
bytesOut: number;
}
interface AccumulatorEntry {
bytesIn: number;
bytesOut: number;
/** Present when the update came through a remote exit node. */
exitNodeId?: number;
/** Whether to record egress usage for billing purposes. */
calcUsage: boolean;
}
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 30_000; // 30 seconds
// In-memory accumulator: publicKey -> AccumulatorEntry
let accumulator = new Map<string, AccumulatorEntry>();
/**
* Check if an error is a deadlock error
*/
@@ -63,6 +75,220 @@ async function withDeadlockRetry<T>(
}
}
/**
* Flush all accumulated site bandwidth data to the database.
*
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued
* back into the accumulator so they will be retried on the next flush.
*
* This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits.
*/
export async function flushSiteBandwidthToDb(): Promise<void> {
if (accumulator.size === 0) {
return;
}
// Atomically swap out the accumulator so new data keeps flowing in
// while we write the snapshot to the database.
const snapshot = accumulator;
accumulator = new Map<string, AccumulatorEntry>();
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers — deadlock-prevention strategy.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
);
logger.debug(
`Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
);
// Aggregate billing usage by org, collected during the DB update loop.
const orgUsageMap = new Map<string, number>();
for (const [publicKey, { bytesIn, bytesOut, exitNodeId, calcUsage }] of sortedEntries) {
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime,
})
.where(eq(sites.pubKey, publicKey))
.returning({
orgId: sites.orgId,
siteId: sites.siteId
});
return result;
}, `flush bandwidth for site ${publicKey}`);
if (updatedSite) {
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
// Skip usage tracking for this site but continue
// processing the rest.
continue;
}
}
if (calcUsage) {
const totalBandwidth = bytesIn + bytesOut;
const current = orgUsageMap.get(updatedSite.orgId) ?? 0;
orgUsageMap.set(updatedSite.orgId, current + totalBandwidth);
}
}
} catch (error) {
logger.error(
`Failed to flush bandwidth for site ${publicKey}:`,
error
);
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, {
bytesIn,
bytesOut,
exitNodeId,
calcUsage
});
}
}
}
// Process billing usage updates outside the site-update loop to keep
// lock scope small and concerns separated.
if (orgUsageMap.size > 0) {
// Sort org IDs for consistent lock ordering.
const sortedOrgIds = [...orgUsageMap.keys()].sort();
for (const orgId of sortedOrgIds) {
try {
const totalBandwidth = orgUsageMap.get(orgId)!;
const bandwidthUsage = await usageService.add(
orgId,
FeatureId.EGRESS_DATA_MB,
totalBandwidth
);
if (bandwidthUsage) {
// Fire-and-forget — don't block the flush on limit checking.
usageService
.checkLimitSet(
orgId,
FeatureId.EGRESS_DATA_MB,
bandwidthUsage
)
.catch((error: any) => {
logger.error(
`Error checking bandwidth limits for org ${orgId}:`,
error
);
});
}
} catch (error) {
logger.error(
`Error processing usage for org ${orgId}:`,
error
);
// Continue with other orgs.
}
}
}
}
// ---------------------------------------------------------------------------
// Periodic flush timer
// ---------------------------------------------------------------------------
const flushTimer = setInterval(async () => {
try {
await flushSiteBandwidthToDb();
} catch (error) {
logger.error(
"Unexpected error during periodic site bandwidth flush:",
error
);
}
}, FLUSH_INTERVAL_MS);
// Allow the process to exit normally even while the timer is pending.
// The graceful-shutdown path (see server/cleanup.ts) will call
// flushSiteBandwidthToDb() explicitly before process.exit(), so no data
// is lost.
flushTimer.unref();
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------
/**
* Accumulate bandwidth data reported by a gerbil or remote exit node.
*
* Only peers that actually transferred data (non-zero bytesIn or bytesOut) are added to the
* accumulator; peers with no activity are silently ignored, which means the
* flush will only write rows that have genuinely changed.
*
* The function is intentionally synchronous in its fast path so that the
* HTTP handler can respond immediately without waiting for any I/O.
*/
export async function updateSiteBandwidth(
bandwidthData: PeerBandwidth[],
calcUsageAndLimits: boolean,
exitNodeId?: number
): Promise<void> {
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
// Skip peers that haven't transferred any data — writing zeros to the
// database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;
}
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
// Retain the most-recent exitNodeId for this peer.
if (exitNodeId !== undefined) {
existing.exitNodeId = exitNodeId;
}
// Once calcUsage has been requested for a peer, keep it set for
// the lifetime of this flush window.
if (calcUsageAndLimits) {
existing.calcUsage = true;
}
} else {
accumulator.set(publicKey, {
bytesIn,
bytesOut,
exitNodeId,
calcUsage: calcUsageAndLimits
});
}
}
}
// ---------------------------------------------------------------------------
// HTTP handler
// ---------------------------------------------------------------------------
export const receiveBandwidth = async (
req: Request,
res: Response,
@@ -75,7 +301,9 @@ export const receiveBandwidth = async (
throw new Error("Invalid bandwidth data");
}
await updateSiteBandwidth(bandwidthData, build == "saas"); // we are checking the usage on saas only
// Accumulate in memory; the periodic timer (and the shutdown hook)
// will write to the database.
await updateSiteBandwidth(bandwidthData, build == "saas");
return response(res, {
data: {},
@@ -94,201 +322,3 @@ export const receiveBandwidth = async (
);
}
};
export async function updateSiteBandwidth(
bandwidthData: PeerBandwidth[],
calcUsageAndLimits: boolean,
exitNodeId?: number
) {
const currentTime = new Date();
const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
// This is critical for preventing deadlocks when multiple instances update the same sites
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
a.publicKey.localeCompare(b.publicKey)
);
// First, handle sites that are actively reporting bandwidth
const activePeers = sortedBandwidthData.filter((peer) => peer.bytesIn > 0);
// Aggregate usage data by organization (collected outside transaction)
const orgUsageMap = new Map<string, number>();
if (activePeers.length > 0) {
// Remove any active peers from offline tracking since they're sending data
activePeers.forEach((peer) => offlineSites.delete(peer.publicKey));
// Update each active site individually with retry logic
// This reduces transaction scope and allows retries per-site
for (const peer of activePeers) {
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
lastBandwidthUpdate: currentTime.toISOString(),
online: true
})
.where(eq(sites.pubKey, peer.publicKey))
.returning({
online: sites.online,
orgId: sites.orgId,
siteId: sites.siteId,
lastBandwidthUpdate: sites.lastBandwidthUpdate
});
return result;
}, `update active site ${peer.publicKey}`);
if (updatedSite) {
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
// Skip this site but continue processing others
continue;
}
}
// Aggregate bandwidth usage for the org
const totalBandwidth = peer.bytesIn + peer.bytesOut;
const currentOrgUsage =
orgUsageMap.get(updatedSite.orgId) || 0;
orgUsageMap.set(
updatedSite.orgId,
currentOrgUsage + totalBandwidth
);
}
} catch (error) {
logger.error(
`Failed to update bandwidth for site ${peer.publicKey}:`,
error
);
// Continue with other sites
}
}
}
// Process usage updates outside of site update transactions
// This separates the concerns and reduces lock contention
if (calcUsageAndLimits && orgUsageMap.size > 0) {
// Sort org IDs to ensure consistent lock ordering
const allOrgIds = [...new Set([...orgUsageMap.keys()])].sort();
for (const orgId of allOrgIds) {
try {
// Process bandwidth usage for this org
const totalBandwidth = orgUsageMap.get(orgId);
if (totalBandwidth) {
const bandwidthUsage = await usageService.add(
orgId,
FeatureId.EGRESS_DATA_MB,
totalBandwidth
);
if (bandwidthUsage) {
// Fire and forget - don't block on limit checking
usageService
.checkLimitSet(
orgId,
FeatureId.EGRESS_DATA_MB,
bandwidthUsage
)
.catch((error: any) => {
logger.error(
`Error checking bandwidth limits for org ${orgId}:`,
error
);
});
}
}
} catch (error) {
logger.error(`Error processing usage for org ${orgId}:`, error);
// Continue with other orgs
}
}
}
// Handle sites that reported zero bandwidth but need online status updated
const zeroBandwidthPeers = sortedBandwidthData.filter(
(peer) => peer.bytesIn === 0 && !offlineSites.has(peer.publicKey)
);
if (zeroBandwidthPeers.length > 0) {
// Fetch all zero bandwidth sites in one query
const zeroBandwidthSites = await db
.select()
.from(sites)
.where(
inArray(
sites.pubKey,
zeroBandwidthPeers.map((p) => p.publicKey)
)
);
// Sort by siteId to ensure consistent lock ordering
const sortedZeroBandwidthSites = zeroBandwidthSites.sort(
(a, b) => a.siteId - b.siteId
);
for (const site of sortedZeroBandwidthSites) {
let newOnlineStatus = site.online;
// Check if site should go offline based on last bandwidth update WITH DATA
if (site.lastBandwidthUpdate) {
const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
if (lastUpdateWithData < oneMinuteAgo) {
newOnlineStatus = false;
}
} else {
// No previous data update recorded, set to offline
newOnlineStatus = false;
}
// Only update online status if it changed
if (site.online !== newOnlineStatus) {
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
online: newOnlineStatus
})
.where(eq(sites.siteId, site.siteId))
.returning();
return result;
}, `update offline status for site ${site.siteId}`);
if (updatedSite && exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
}
}
// If site went offline, add it to our tracking set
if (!newOnlineStatus && site.pubKey) {
offlineSites.add(site.pubKey);
}
} catch (error) {
logger.error(
`Failed to update offline status for site ${site.siteId}:`,
error
);
// Continue with other sites
}
}
}
}
}

View File

@@ -112,7 +112,7 @@ export async function updateHolePunch(
destinations: destinations
});
} catch (error) {
// logger.error(error); // FIX THIS
logger.error(error);
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
@@ -262,7 +262,7 @@ export async function updateAndGenerateEndpointDestinations(
if (site.subnet && site.listenPort) {
destinations.push({
destinationIP: site.subnet.split("/")[0],
destinationPort: site.listenPort
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
});
}
}
@@ -339,10 +339,10 @@ export async function updateAndGenerateEndpointDestinations(
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!);
}
if (!updatedSite || !updatedSite.subnet) {
logger.warn(`Site not found: ${newt.siteId}`);
throw new Error("Site not found");
}
// if (!updatedSite || !updatedSite.subnet) {
// logger.warn(`Site not found: ${newt.siteId}`);
// throw new Error("Site not found");
// }
// Find all clients that connect to this site
// const sitesClientPairs = await db

View File

@@ -27,7 +27,7 @@ registry.registerPath({
method: "put",
path: "/idp/{idpId}/org/{orgId}",
description: "Create an IDP policy for an existing IDP on an organization.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -37,7 +37,7 @@ registry.registerPath({
method: "put",
path: "/idp/oidc",
description: "Create an OIDC IdP.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
body: {
content: {

View File

@@ -21,7 +21,7 @@ registry.registerPath({
method: "delete",
path: "/idp/{idpId}",
description: "Delete IDP.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema
},

View File

@@ -19,7 +19,7 @@ registry.registerPath({
method: "delete",
path: "/idp/{idpId}/org/{orgId}",
description: "Create an OIDC IdP for an organization.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema
},

View File

@@ -34,7 +34,7 @@ registry.registerPath({
method: "get",
path: "/idp/{idpId}",
description: "Get an IDP by its IDP ID.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema
},

View File

@@ -48,7 +48,7 @@ registry.registerPath({
method: "get",
path: "/idp/{idpId}/org",
description: "List all org policies on an IDP.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema,
query: querySchema

View File

@@ -58,7 +58,7 @@ registry.registerPath({
method: "get",
path: "/idp",
description: "List all IDP in the system.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
query: querySchema
},

View File

@@ -26,7 +26,7 @@ registry.registerPath({
method: "post",
path: "/idp/{idpId}/org/{orgId}",
description: "Update an IDP org policy.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -42,7 +42,7 @@ registry.registerPath({
method: "post",
path: "/idp/{idpId}/oidc",
description: "Update an OIDC IdP.",
tags: [OpenAPITags.Idp],
tags: [OpenAPITags.GlobalIdp],
request: {
params: paramsSchema,
body: {

View File

@@ -135,6 +135,13 @@ authenticated.post(
logActionAudit(ActionsEnum.updateSite),
site.updateSite
);
authenticated.post(
"/org/:orgId/reset-bandwidth",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.resetSiteBandwidth),
logActionAudit(ActionsEnum.resetSiteBandwidth),
org.resetOrgBandwidth
);
authenticated.delete(
"/site/:siteId",
@@ -309,6 +316,14 @@ authenticated.post(
siteResource.removeClientFromSiteResource
);
authenticated.post(
"/client/:clientId/site-resources",
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
logActionAudit(ActionsEnum.setResourceUsers),
siteResource.batchAddClientToSiteResources
);
authenticated.put(
"/org/:orgId/resource",
verifyApiKeyOrgAccess,

View File

@@ -15,6 +15,7 @@ import { initPeerAddHandshake, updatePeer } from "../olm/peers";
import { eq, and } from "drizzle-orm";
import config from "@server/lib/config";
import {
formatEndpoint,
generateSubnetProxyTargetV2,
SubnetProxyTargetV2
} from "@server/lib/ip";
@@ -83,40 +84,42 @@ export async function buildClientConfigurationForNewtClient(
// )
// );
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint!,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey!,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through JIT then don't add the site to the olm
// update the peer info on the olm
// if the peer has not been added yet this will be a no-op
await updatePeer(client.clients.clientId, {
siteId: site.siteId,
endpoint: site.endpoint!,
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
publicKey: site.publicKey!,
serverIP: site.address,
serverPort: site.listenPort
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// ),
// aliases: generateAliasConfig(
// allSiteResources.map(
// ({ siteResources }) => siteResources
// )
// )
});
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clients.clientId,
{
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clients.clientId,
{
siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
}
}
);
);
}
return {
publicKey: client.clients.pubKey!,
@@ -204,7 +207,8 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
hcTimeout: targetHealthCheck.hcTimeout,
hcHeaders: targetHealthCheck.hcHeaders,
hcMethod: targetHealthCheck.hcMethod,
hcTlsServerName: targetHealthCheck.hcTlsServerName
hcTlsServerName: targetHealthCheck.hcTlsServerName,
hcStatus: targetHealthCheck.hcStatus
})
.from(targets)
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
@@ -221,8 +225,8 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
return acc;
}
// Format target into string
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
// Format target into string (handles IPv6 bracketing)
const formattedTarget = `${target.internalPort}:${formatEndpoint(target.ip, target.port)}`;
// Add to the appropriate protocol array
if (target.protocol === "tcp") {
@@ -245,9 +249,9 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
!target.hcInterval ||
!target.hcMethod
) {
logger.debug(
`Skipping target ${target.targetId} due to missing health check fields`
);
// logger.debug(
// `Skipping adding target health check ${target.targetId} due to missing health check fields`
// );
return null; // Skip targets with missing health check fields
}
@@ -277,7 +281,8 @@ export async function buildTargetConfigurationForNewtClient(siteId: number) {
hcTimeout: target.hcTimeout, // in seconds
hcHeaders: hcHeadersSend,
hcMethod: target.hcMethod,
hcTlsServerName: target.hcTlsServerName
hcTlsServerName: target.hcTlsServerName,
hcStatus: target.hcStatus
};
});

View File

@@ -7,6 +7,7 @@ import { eq } from "drizzle-orm";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
import { convertTargetsIfNessicary } from "../client/targets";
import { canCompress } from "@server/lib/clientVersionChecks";
const inputSchema = z.object({
publicKey: z.string(),
@@ -105,11 +106,11 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
const payload = {
oldDestination: {
destinationIP: existingSite.subnet?.split("/")[0],
destinationPort: existingSite.listenPort
destinationPort: existingSite.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
},
newDestination: {
destinationIP: site.subnet?.split("/")[0],
destinationPort: site.listenPort
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
}
};
@@ -138,6 +139,9 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
targets: targetsToSend
}
},
options: {
compress: canCompress(newt.version, "newt")
},
broadcast: false,
excludeSender: false
};

View File

@@ -0,0 +1,36 @@
import { MessageHandler } from "@server/routers/ws";
import { db, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
import logger from "@server/logger";
/**
* Handles disconnecting messages from sites so they show as disconnected in the UI
*/
export const handleNewtDisconnectingMessage: MessageHandler = async (
context
) => {
const { message, client: c, sendToClient } = context;
const newt = c as Newt;
if (!newt) {
logger.warn("Newt not found");
return;
}
if (!newt.siteId) {
logger.warn("Newt has no client ID!");
return;
}
try {
// Mark the site as offline
await db
.update(sites)
.set({
online: false
})
.where(eq(sites.siteId, newt.siteId));
} catch (error) {
logger.error("Error handling disconnecting message", { error });
}
};

View File

@@ -1,105 +1,107 @@
import { db, sites } from "@server/db";
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
import { db, newts, sites } from "@server/db";
import { hasActiveConnections, getClientConfigVersion } from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { clients, Newt } from "@server/db";
import { Newt } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { sendTerminateClient } from "../client/terminate";
import { encodeHexLowerCase } from "@oslojs/encoding";
import { sha256 } from "@oslojs/crypto/sha2";
import { sendNewtSyncMessage } from "./sync";
// Track if the offline checker interval is running
// let offlineCheckerInterval: NodeJS.Timeout | null = null;
// const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
// const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
/**
* Starts the background interval that checks for clients that haven't pinged recently
* and marks them as offline
* Starts the background interval that checks for newt sites that haven't
* pinged recently and marks them as offline. For backward compatibility,
* a site is only marked offline when there is no active WebSocket connection
* either — so older newt versions that don't send pings but remain connected
* continue to be treated as online.
*/
// export const startNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) {
// return; // Already running
// }
export const startNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
return; // Already running
}
// offlineCheckerInterval = setInterval(async () => {
// try {
// const twoMinutesAgo = Math.floor(
// (Date.now() - OFFLINE_THRESHOLD_MS) / 1000
// );
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// // TODO: WE NEED TO MAKE SURE THIS WORKS WITH DISTRIBUTED NODES ALL DOING THE SAME THING
// Find all online newt-type sites that haven't pinged recently
// (or have never pinged at all). Join newts to obtain the newtId
// needed for the WebSocket connection check.
const staleSites = await db
.select({
siteId: sites.siteId,
newtId: newts.newtId,
lastPing: sites.lastPing
})
.from(sites)
.innerJoin(newts, eq(newts.siteId, sites.siteId))
.where(
and(
eq(sites.online, true),
eq(sites.type, "newt"),
or(
lt(sites.lastPing, twoMinutesAgo),
isNull(sites.lastPing)
)
)
);
// // Find clients that haven't pinged in the last 2 minutes and mark them as offline
// const offlineClients = await db
// .update(clients)
// .set({ online: false })
// .where(
// and(
// eq(clients.online, true),
// or(
// lt(clients.lastPing, twoMinutesAgo),
// isNull(clients.lastPing)
// )
// )
// )
// .returning();
for (const staleSite of staleSites) {
// Backward-compatibility check: if the newt still has an
// active WebSocket connection (older clients that don't send
// pings), keep the site online.
const isConnected = await hasActiveConnections(staleSite.newtId);
if (isConnected) {
logger.debug(
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
);
continue;
}
// for (const offlineClient of offlineClients) {
// logger.info(
// `Kicking offline newt client ${offlineClient.clientId} due to inactivity`
// );
logger.info(
`Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
);
// if (!offlineClient.newtId) {
// logger.warn(
// `Offline client ${offlineClient.clientId} has no newtId, cannot disconnect`
// );
// continue;
// }
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, staleSite.siteId));
}
} catch (error) {
logger.error("Error in newt offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
// // Send a disconnect message to the client if connected
// try {
// await sendTerminateClient(
// offlineClient.clientId,
// offlineClient.newtId
// ); // terminate first
// // wait a moment to ensure the message is sent
// await new Promise((resolve) => setTimeout(resolve, 1000));
// await disconnectClient(offlineClient.newtId);
// } catch (error) {
// logger.error(
// `Error sending disconnect to offline newt ${offlineClient.clientId}`,
// { error }
// );
// }
// }
// } catch (error) {
// logger.error("Error in offline checker interval", { error });
// }
// }, OFFLINE_CHECK_INTERVAL);
// logger.debug("Started offline checker interval");
// };
logger.debug("Started newt offline checker interval");
};
/**
* Stops the background interval that checks for offline clients
* Stops the background interval that checks for offline newt sites.
*/
// export const stopNewtOfflineChecker = (): void => {
// if (offlineCheckerInterval) {
// clearInterval(offlineCheckerInterval);
// offlineCheckerInterval = null;
// logger.info("Stopped offline checker interval");
// }
// };
export const stopNewtOfflineChecker = (): void => {
if (offlineCheckerInterval) {
clearInterval(offlineCheckerInterval);
offlineCheckerInterval = null;
logger.info("Stopped newt offline checker interval");
}
};
/**
* Handles ping messages from clients and responds with pong
* Handles ping messages from newt clients.
*
* On each ping:
* - Marks the associated site as online.
* - Records the current timestamp as the newt's last-ping time.
* - Triggers a config sync if the newt is running an outdated config version.
* - Responds with a pong message.
*/
export const handleNewtPingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const { message, client: c } = context;
const newt = c as Newt;
if (!newt) {
@@ -112,15 +114,31 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
return;
}
// get the version
try {
// Mark the site as online and record the ping timestamp.
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newt.siteId));
} catch (error) {
logger.error("Error updating online state on newt ping", { error });
}
// Check config version and sync if stale.
const configVersion = await getClientConfigVersion(newt.newtId);
if (message.configVersion && configVersion != null && configVersion != message.configVersion) {
if (
message.configVersion != null &&
configVersion != null &&
configVersion !== message.configVersion
) {
logger.warn(
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
);
// get the site
const [site] = await db
.select()
.from(sites)
@@ -137,19 +155,6 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
await sendNewtSyncMessage(newt, site);
}
// try {
// // Update the client's last ping timestamp
// await db
// .update(clients)
// .set({
// lastPing: Math.floor(Date.now() / 1000),
// online: true
// })
// .where(eq(clients.clientId, newt.clientId));
// } catch (error) {
// logger.error("Error handling ping message", { error });
// }
return {
message: {
type: "pong",

View File

@@ -5,9 +5,7 @@ import { eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../gerbil/peers";
import logger from "@server/logger";
import config from "@server/lib/config";
import {
findNextAvailableCidr,
} from "@server/lib/ip";
import { findNextAvailableCidr } from "@server/lib/ip";
import {
selectBestExitNode,
verifyExitNodeOrgAccess
@@ -15,6 +13,7 @@ import {
import { fetchContainers } from "./dockerSocket";
import { lockManager } from "#dynamic/lib/lock";
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
export type ExitNodePingResult = {
exitNodeId: number;
@@ -215,6 +214,9 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
healthCheckTargets: validHealthCheckTargets
}
},
options: {
compress: canCompress(newt.version, "newt")
},
broadcast: false, // Send to all clients
excludeSender: false // Include sender in broadcast
};

View File

@@ -10,10 +10,21 @@ interface PeerBandwidth {
bytesOut: number;
}
interface BandwidthAccumulator {
bytesIn: number;
bytesOut: number;
}
// Retry configuration for deadlock handling
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 120_000; // 120 seconds
// In-memory accumulator: publicKey -> { bytesIn, bytesOut }
let accumulator = new Map<string, BandwidthAccumulator>();
/**
* Check if an error is a deadlock error
*/
@@ -53,6 +64,90 @@ async function withDeadlockRetry<T>(
}
}
/**
* Flush all accumulated bandwidth data to the database.
*
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued
* back into the accumulator so they will be retried on the next flush.
*
* This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits.
*/
export async function flushBandwidthToDb(): Promise<void> {
if (accumulator.size === 0) {
return;
}
// Atomically swap out the accumulator so new data keeps flowing in
// while we write the snapshot to the database.
const snapshot = accumulator;
accumulator = new Map<string, BandwidthAccumulator>();
const currentTime = new Date().toISOString();
// Sort by publicKey for consistent lock ordering across concurrent
// writers — this is the same deadlock-prevention strategy used in the
// original per-message implementation.
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
a.localeCompare(b)
);
logger.debug(
`Flushing accumulated bandwidth data for ${sortedEntries.length} client(s) to the database`
);
for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) {
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid the SELECT-then-UPDATE
// anti-pattern and the races it would introduce.
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data
// sent to client) and bytesOut from peer goes to
// megabytesIn (data received from client).
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `flush bandwidth for client ${publicKey}`);
} catch (error) {
logger.error(
`Failed to flush bandwidth for client ${publicKey}:`,
error
);
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, { bytesIn, bytesOut });
}
}
}
}
const flushTimer = setInterval(async () => {
try {
await flushBandwidthToDb();
} catch (error) {
logger.error("Unexpected error during periodic bandwidth flush:", error);
}
}, FLUSH_INTERVAL_MS);
// Calling unref() means this timer will not keep the Node.js event loop alive
// on its own — the process can still exit normally when there is no other work
// left. The graceful-shutdown path (see server/cleanup.ts) will call
// flushBandwidthToDb() explicitly before process.exit(), so no data is lost.
flushTimer.unref();
export const handleReceiveBandwidthMessage: MessageHandler = async (
context
) => {
@@ -69,40 +164,21 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
throw new Error("Invalid bandwidth data");
}
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
// This is critical for preventing deadlocks when multiple instances update the same clients
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
a.publicKey.localeCompare(b.publicKey)
);
// Accumulate the incoming data in memory; the periodic timer (and the
// shutdown hook) will take care of writing it to the database.
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
// Skip peers that haven't transferred any data — writing zeros to the
// database would be a no-op anyway.
if (bytesIn <= 0 && bytesOut <= 0) {
continue;
}
const currentTime = new Date().toISOString();
// Update each client individually with retry logic
// This reduces transaction scope and allows retries per-client
for (const peer of sortedBandwidthData) {
const { publicKey, bytesIn, bytesOut } = peer;
try {
await withDeadlockRetry(async () => {
// Use atomic SQL increment to avoid SELECT then UPDATE pattern
// This eliminates the need to read the current value first
await db
.update(clients)
.set({
// Note: bytesIn from peer goes to megabytesOut (data sent to client)
// and bytesOut from peer goes to megabytesIn (data received from client)
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime
})
.where(eq(clients.pubKey, publicKey));
}, `update client bandwidth ${publicKey}`);
} catch (error) {
logger.error(
`Failed to update bandwidth for client ${publicKey}:`,
error
);
// Continue with other clients even if one fails
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, { bytesIn, bytesOut });
}
}
};
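The comments above defer the final write to a graceful-shutdown path in server/cleanup.ts, which is not included in this diff. A minimal sketch of such a hook, assuming the cleanup module simply awaits the exported flush functions before exiting (import paths and handler names are assumptions):

// Hypothetical sketch only; server/cleanup.ts is not part of this diff.
import logger from "@server/logger";
import { flushBandwidthToDb } from "./handleReceiveBandwidthMessage"; // path assumed
import { flushSiteBandwidthToDb } from "@server/routers/gerbil/receiveBandwidth"; // path assumed

async function onShutdown(signal: string) {
    logger.info(`Received ${signal}; flushing bandwidth accumulators before exit`);
    try {
        await Promise.all([flushBandwidthToDb(), flushSiteBandwidthToDb()]);
    } catch (error) {
        logger.error("Failed to flush bandwidth data during shutdown:", error);
    }
    process.exit(0);
}

process.on("SIGINT", () => void onShutdown("SIGINT"));
process.on("SIGTERM", () => void onShutdown("SIGTERM"));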

View File

@@ -7,3 +7,4 @@ export * from "./handleSocketMessages";
export * from "./handleNewtPingRequestMessage";
export * from "./handleApplyBlueprintMessage";
export * from "./handleNewtPingMessage";
export * from "./handleNewtDisconnectingMessage";

View File

@@ -6,6 +6,7 @@ import {
buildClientConfigurationForNewtClient,
buildTargetConfigurationForNewtClient
} from "./buildConfiguration";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
const { tcpTargets, udpTargets, validHealthCheckTargets } =
@@ -24,18 +25,24 @@ export async function sendNewtSyncMessage(newt: Newt, site: Site) {
exitNode
);
await sendToClient(newt.newtId, {
type: "newt/sync",
data: {
proxyTargets: {
udp: udpTargets,
tcp: tcpTargets
},
healthCheckTargets: validHealthCheckTargets,
peers: peers,
clientTargets: targets
await sendToClient(
newt.newtId,
{
type: "newt/sync",
data: {
proxyTargets: {
udp: udpTargets,
tcp: tcpTargets
},
healthCheckTargets: validHealthCheckTargets,
peers: peers,
clientTargets: targets
}
},
{
compress: canCompress(newt.version, "newt")
}
}).catch((error) => {
).catch((error) => {
logger.warn(`Error sending newt sync message:`, error);
});
}

View File

@@ -2,13 +2,14 @@ import { Target, TargetHealthCheck, db, targetHealthCheck } from "@server/db";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function addTargets(
newtId: string,
targets: Target[],
healthCheckData: TargetHealthCheck[],
protocol: string,
port: number | null = null
version?: string | null
) {
//create a list of udp and tcp targets
const payloadTargets = targets.map((target) => {
@@ -22,7 +23,7 @@ export async function addTargets(
data: {
targets: payloadTargets
}
}, { incrementConfigVersion: true });
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
// Create a map for quick lookup
const healthCheckMap = new Map<number, TargetHealthCheck>();
@@ -103,14 +104,14 @@ export async function addTargets(
data: {
targets: validHealthCheckTargets
}
}, { incrementConfigVersion: true });
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
}
export async function removeTargets(
newtId: string,
targets: Target[],
protocol: string,
port: number | null = null
version?: string | null
) {
//create a list of udp and tcp targets
const payloadTargets = targets.map((target) => {
@@ -135,5 +136,5 @@ export async function removeTargets(
data: {
ids: healthCheckTargets
}
}, { incrementConfigVersion: true });
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
}

View File

@@ -1,5 +1,17 @@
import { Client, clientSiteResourcesAssociationsCache, clientSitesAssociationsCache, db, exitNodes, siteResources, sites } from "@server/db";
import { generateAliasConfig, generateRemoteSubnets } from "@server/lib/ip";
import {
Client,
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
siteResources,
sites
} from "@server/db";
import {
Alias,
generateAliasConfig,
generateRemoteSubnets
} from "@server/lib/ip";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { addPeer, deletePeer } from "../newt/peers";
@@ -8,9 +20,19 @@ import config from "@server/lib/config";
export async function buildSiteConfigurationForOlmClient(
client: Client,
publicKey: string | null,
relay: boolean
relay: boolean,
jitMode: boolean = false
) {
const siteConfigurations = [];
const siteConfigurations: {
siteId: number;
name?: string;
endpoint?: string;
publicKey?: string;
serverIP?: string | null;
serverPort?: number | null;
remoteSubnets?: string[];
aliases: Alias[];
}[] = [];
// Get all sites data
const sitesData = await db
@@ -27,6 +49,40 @@ export async function buildSiteConfigurationForOlmClient(
sites: site,
clientSitesAssociationsCache: association
} of sitesData) {
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (jitMode) {
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
// remoteSubnets: generateRemoteSubnets(
// allSiteResources.map(({ siteResources }) => siteResources)
// ),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
});
continue;
}
if (!site.exitNodeId) {
logger.warn(
`Site ${site.siteId} does not have exit node, skipping`
@@ -42,6 +98,13 @@ export async function buildSiteConfigurationForOlmClient(
continue;
}
if (!site.publicKey || site.publicKey == "") { // the site is not ready to accept new peers
logger.warn(
`Site ${site.siteId} has no public key, skipping`
);
continue;
}
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
// logger.warn(
// `Site ${site.siteId} last hole punch is too old, skipping`
@@ -103,26 +166,6 @@ export async function buildSiteConfigurationForOlmClient(
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
}
const allSiteResources = await db // only get the site resources that this client has access to
.select()
.from(siteResources)
.innerJoin(
clientSiteResourcesAssociationsCache,
eq(
siteResources.siteResourceId,
clientSiteResourcesAssociationsCache.siteResourceId
)
)
.where(
and(
eq(siteResources.siteId, site.siteId),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
// Add site configuration to the array
siteConfigurations.push({
siteId: site.siteId,
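
In JIT mode (added above) a site entry carries only the site ID and the resource aliases the client may resolve; endpoints, keys, and subnets are withheld until the client asks for the peer. A sketch of the resulting entry; the Alias field names are an assumption, not shown in this diff:

// Shape of a JIT-mode entry pushed into siteConfigurations above.
// The Alias structure is assumed for illustration; see generateAliasConfig in @server/lib/ip.
const jitEntry = {
    siteId: 12,
    aliases: [
        { alias: "db.internal", ip: "100.89.0.5" } // hypothetical values
    ]
};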

View File

@@ -6,7 +6,7 @@ import logger from "@server/logger";
/**
* Handles disconnecting messages from clients to show disconnected in the ui
*/
export const handleOlmDisconnecingMessage: MessageHandler = async (context) => {
export const handleOlmDisconnectingMessage: MessageHandler = async (context) => {
const { message, client: c, sendToClient } = context;
const olm = c as Olm;

View File

@@ -17,6 +17,9 @@ import { getUserDeviceName } from "@server/db/names";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { OlmErrorCodes, sendOlmError } from "./error";
import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
@@ -207,6 +210,32 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
}
}
// Count the sites associated with this client
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
let jitMode = false;
if (sitesCount > 250 && build == "saas") {
// 250 is the maximum number of sites on the business tier
// If we have more sites than that, drop into full JIT mode and do not send any site configs up front
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount);
jitMode = true;
}
logger.debug(
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
);
@@ -233,28 +262,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
await db
.update(clientSitesAssociationsCache)
.set({
isRelayed: relay == true
isRelayed: relay == true,
isJitMode: jitMode
})
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
}
// Get all sites data
const sitesCountResult = await db
.select({ count: count() })
.from(sites)
.innerJoin(
clientSitesAssociationsCache,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract the count value from the result array
const sitesCount =
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
// Prepare an array to store site configurations
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
// this prevents us from accepting a register from an olm that has not hole punched yet.
// the olm will pump the register so we can keep checking
// TODO: I still think there is a better way to do this rather than locking it out here but ???
@@ -265,19 +278,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
return;
}
// NOTE: it's important that the client here is the old client and the public key is the new key
const siteConfigurations = await buildSiteConfigurationForOlmClient(
client,
publicKey,
relay
relay,
jitMode
);
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
// if (siteConfigurations.length === 0) {
// logger.warn("No valid site configurations found");
// return;
// }
// Return connect message with all site configurations
return {
message: {
@@ -288,6 +296,9 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
utilitySubnet: org.utilitySubnet
}
},
options: {
compress: canCompress(olm.version, "olm")
},
broadcast: false,
excludeSender: false
};
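
The register handler above decides up front whether to fall back to JIT mode. A minimal restatement of that guard as a helper; the constant name is hypothetical, while the 250-site limit and the "saas" build check come from the handler itself:

// Mirrors the check in handleOlmRegisterMessage: only the hosted ("saas") build caps
// eagerly-sent sites at the business-tier maximum of 250.
const MAX_EAGER_SITES = 250; // hypothetical constant name

function shouldUseJitMode(sitesCount: number, build: string): boolean {
    return build === "saas" && sitesCount > MAX_EAGER_SITES;
}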

View File

@@ -18,7 +18,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
logger.warn("Olm has no client!");
return;
}
@@ -41,7 +41,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
return;
}
const { siteId } = message.data;
const { siteId, chainId } = message.data;
// Get the site
const [site] = await db
@@ -90,7 +90,8 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
data: {
siteId: siteId,
relayEndpoint: exitNode.endpoint,
relayPort: config.getRawConfig().gerbil.clients_start_port
relayPort: config.getRawConfig().gerbil.clients_start_port,
chainId
}
},
broadcast: false,

View File

@@ -0,0 +1,241 @@
import {
clientSiteResourcesAssociationsCache,
clientSitesAssociationsCache,
db,
exitNodes,
Site,
siteResources
} from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, Olm, sites } from "@server/db";
import { and, eq, or } from "drizzle-orm";
import logger from "@server/logger";
import { initPeerAddHandshake } from "./peers";
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
context
) => {
logger.info("Handling register olm message!");
const { message, client: c, sendToClient } = context;
const olm = c as Olm;
if (!olm) {
logger.warn("Olm not found");
return;
}
if (!olm.clientId) {
logger.warn("Olm has no client!"); // TODO: Maybe we create the site here?
return;
}
const clientId = olm.clientId;
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
logger.warn("Client not found");
return;
}
const { siteId, resourceId, chainId } = message.data;
let site: Site | null = null;
if (siteId) {
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (siteRes) {
site = siteRes;
}
}
if (resourceId && !site) {
const resources = await db
.select()
.from(siteResources)
.where(
and(
or(
eq(siteResources.niceId, resourceId),
eq(siteResources.alias, resourceId)
),
eq(siteResources.orgId, client.orgId)
)
);
if (!resources || resources.length === 0) {
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (resources.length > 1) {
// This should not happen: a niceId cannot contain a dot, an alias must contain a dot, and both are unique within the org, so there should never be multiple matches
logger.error(
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
);
return;
}
const resource = resources[0];
const currentResourceAssociationCaches = await db
.select()
.from(clientSiteResourcesAssociationsCache)
.where(
and(
eq(
clientSiteResourcesAssociationsCache.siteResourceId,
resource.siteResourceId
),
eq(
clientSiteResourcesAssociationsCache.clientId,
client.clientId
)
)
);
if (currentResourceAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
const siteIdFromResource = resource.siteId;
// get the site
const [siteRes] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteIdFromResource));
if (!siteRes) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${siteIdFromResource} not found`
);
return;
}
site = siteRes;
}
if (!site) {
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
return;
}
// check if the client can access this site using the cache
const currentSiteAssociationCaches = await db
.select()
.from(clientSitesAssociationsCache)
.where(
and(
eq(clientSitesAssociationsCache.clientId, client.clientId),
eq(clientSitesAssociationsCache.siteId, site.siteId)
)
);
if (currentSiteAssociationCaches.length === 0) {
logger.error(
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
if (!site.exitNodeId) {
logger.error(
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
);
// cancel the request from the olm side to not keep doing this
await sendToClient(
olm.olmId,
{
type: "olm/wg/peer/chain/cancel",
data: {
chainId
}
},
{ incrementConfigVersion: false }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
return;
}
// get the exit node from the site
const [exitNode] = await db
.select()
.from(exitNodes)
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
if (!exitNode) {
logger.error(
`handleOlmServerPeerAddMessage: Exit node ${site.exitNodeId} for site ${site.siteId} not found`
);
return;
}
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
// if it has already been added this will be a no-op
await initPeerAddHandshake(
// this will kick off the add peer process for the client
client.clientId,
{
siteId: site.siteId,
exitNode: {
publicKey: exitNode.publicKey,
endpoint: exitNode.endpoint
}
},
olm.olmId,
chainId
);
return;
};
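
The handler above matches resourceId against either a niceId or an alias, relying on the invariant noted in its comment that a niceId never contains a dot while an alias always does. A small hypothetical helper illustrating that disambiguation, not part of this diff:

// A niceId cannot contain a dot and an alias must, so one check disambiguates them.
function classifyResourceIdentifier(id: string): "alias" | "niceId" {
    return id.includes(".") ? "alias" : "niceId";
}

classifyResourceIdentifier("db.internal"); // "alias"
classifyResourceIdentifier("my-postgres"); // "niceId"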

View File

@@ -54,7 +54,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
return;
}
const { siteId } = message.data;
const { siteId, chainId } = message.data;
// get the site
const [site] = await db
@@ -179,7 +179,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
),
aliases: generateAliasConfig(
allSiteResources.map(({ siteResources }) => siteResources)
)
),
chainId: chainId,
}
},
broadcast: false,

View File

@@ -17,7 +17,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
}
if (!olm.clientId) {
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
logger.warn("Olm has no client!");
return;
}
@@ -40,7 +40,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
return;
}
const { siteId } = message.data;
const { siteId, chainId } = message.data;
// Get the site
const [site] = await db
@@ -87,7 +87,8 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
type: "olm/wg/peer/unrelay",
data: {
siteId: siteId,
endpoint: site.endpoint
endpoint: site.endpoint,
chainId
}
},
broadcast: false,

View File

@@ -11,3 +11,4 @@ export * from "./handleOlmServerPeerAddMessage";
export * from "./handleOlmUnRelayMessage";
export * from "./recoverOlmWithFingerprint";
export * from "./handleOlmDisconnectingMessage";
export * from "./handleOlmServerInitAddPeerHandshake";

View File

@@ -1,8 +1,9 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms } from "@server/db";
import { clientSitesAssociationsCache, db, olms } from "@server/db";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
import { and, eq } from "drizzle-orm";
import { Alias } from "yaml";
export async function addPeer(
@@ -18,7 +19,8 @@ export async function addPeer(
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
aliases: Alias[];
},
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -30,6 +32,7 @@ export async function addPeer(
return; // ignore this because an olm might not be associated with the client anymore
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -48,7 +51,7 @@ export async function addPeer(
aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -60,7 +63,8 @@ export async function deletePeer(
clientId: number,
siteId: number,
publicKey: string,
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -72,6 +76,7 @@ export async function deletePeer(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -83,7 +88,7 @@ export async function deletePeer(
siteId: siteId
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -103,7 +108,8 @@ export async function updatePeer(
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
aliases?: Alias[] | null;
},
olmId?: string
olmId?: string,
version?: string | null
) {
if (!olmId) {
const [olm] = await db
@@ -115,6 +121,7 @@ export async function updatePeer(
return;
}
olmId = olm.olmId;
version = olm.version;
}
await sendToClient(
@@ -132,7 +139,7 @@ export async function updatePeer(
aliases: peer.aliases
}
},
{ incrementConfigVersion: true }
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
).catch((error) => {
logger.warn(`Error sending message:`, error);
});
@@ -149,7 +156,8 @@ export async function initPeerAddHandshake(
endpoint: string;
};
},
olmId?: string
olmId?: string,
chainId?: string
) {
if (!olmId) {
const [olm] = await db
@@ -173,7 +181,8 @@ export async function initPeerAddHandshake(
publicKey: peer.exitNode.publicKey,
relayPort: config.getRawConfig().gerbil.clients_start_port,
endpoint: peer.exitNode.endpoint
}
},
chainId
}
},
{ incrementConfigVersion: true }
@@ -181,6 +190,17 @@ export async function initPeerAddHandshake(
logger.warn(`Error sending message:`, error);
});
// Update the clientSitesAssociationsCache to set isJitMode to false so that JIT mode stays disabled for this site if the client reconnects later
await db
.update(clientSitesAssociationsCache)
.set({ isJitMode: false })
.where(
and(
eq(clientSitesAssociationsCache.clientId, clientId),
eq(clientSitesAssociationsCache.siteId, peer.siteId)
)
);
logger.info(
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
);
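
Every sendToClient call in these peer helpers now passes the same options bag. The shape below is inferred from the call sites in this diff; the real type exported by #dynamic/routers/ws may differ:

// Inferred from call sites such as { incrementConfigVersion: true, compress: canCompress(...) }.
interface SendToClientOptions {
    // Bump the client's stored config version so it can detect missed updates.
    incrementConfigVersion?: boolean;
    // Compress the payload when the client's version supports it (see canCompress).
    compress?: boolean;
}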

View File

@@ -1,9 +1,17 @@
import { Client, db, exitNodes, Olm, sites, clientSitesAssociationsCache } from "@server/db";
import {
Client,
db,
exitNodes,
Olm,
sites,
clientSitesAssociationsCache
} from "@server/db";
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
import { sendToClient } from "#dynamic/routers/ws";
import logger from "@server/logger";
import { eq, inArray } from "drizzle-orm";
import config from "@server/lib/config";
import { canCompress } from "@server/lib/clientVersionChecks";
export async function sendOlmSyncMessage(olm: Olm, client: Client) {
// NOTE: WE ARE HARDCODING THE RELAY PARAMETER TO FALSE HERE BUT IN THE REGISTER MESSAGE ITS DEFINED BY THE CLIENT
@@ -17,10 +25,7 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
const clientSites = await db
.select()
.from(clientSitesAssociationsCache)
.innerJoin(
sites,
eq(sites.siteId, clientSitesAssociationsCache.siteId)
)
.innerJoin(sites, eq(sites.siteId, clientSitesAssociationsCache.siteId))
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
// Extract unique exit node IDs
@@ -68,13 +73,20 @@ export async function sendOlmSyncMessage(olm: Olm, client: Client) {
logger.debug("sendOlmSyncMessage: sending sync message");
await sendToClient(olm.olmId, {
type: "olm/sync",
data: {
sites: siteConfigurations,
exitNodes: exitNodesData
await sendToClient(
olm.olmId,
{
type: "olm/sync",
data: {
sites: siteConfigurations,
exitNodes: exitNodesData
}
},
{
compress: canCompress(olm.version, "olm")
}
}).catch((error) => {
).catch((error) => {
logger.warn(`Error sending olm sync message:`, error);
});
}

View File

@@ -8,3 +8,4 @@ export * from "./getOrgOverview";
export * from "./listOrgs";
export * from "./pickOrgDefaults";
export * from "./checkOrgUserAccess";
export * from "./resetOrgBandwidth";

View File

@@ -0,0 +1,83 @@
import { NextFunction, Request, Response } from "express";
import { z } from "zod";
import { db, sites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const resetOrgBandwidthParamsSchema = z.strictObject({
orgId: z.string()
});
registry.registerPath({
method: "post",
path: "/org/{orgId}/reset-bandwidth",
description: "Reset all sites in selected organization bandwidth counters.",
tags: [OpenAPITags.Org, OpenAPITags.Site],
request: {
params: resetOrgBandwidthParamsSchema
},
responses: {}
});
export async function resetOrgBandwidth(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = resetOrgBandwidthParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const [site] = await db
.select({ siteId: sites.siteId })
.from(sites)
.where(eq(sites.orgId, orgId))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`No sites found in org ${orgId}`
)
);
}
await db
.update(sites)
.set({
megabytesIn: 0,
megabytesOut: 0
})
.where(eq(sites.orgId, orgId));
return response(res, {
data: {},
success: true,
error: false,
message: "Sites bandwidth reset successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
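
A hedged example of invoking the new endpoint from an API client; the base URL, version prefix, and auth header are assumptions, not taken from this diff:

// Hypothetical request against the route registered above.
await fetch("https://pangolin.example.com/v1/org/my-org/reset-bandwidth", {
    method: "POST",
    headers: { Authorization: "Bearer <api-key>" }
});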

View File

@@ -29,7 +29,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/whitelist/add",
description: "Add a single email to the resource whitelist.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: addEmailToResourceWhitelistParamsSchema,
body: {

View File

@@ -29,7 +29,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/roles/add",
description: "Add a single role to a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
tags: [OpenAPITags.PublicResource, OpenAPITags.Role],
request: {
params: addRoleToResourceParamsSchema,
body: {

View File

@@ -29,7 +29,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/users/add",
description: "Add a single user to a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.User],
tags: [OpenAPITags.PublicResource, OpenAPITags.User],
request: {
params: addUserToResourceParamsSchema,
body: {

View File

@@ -79,7 +79,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/resource",
description: "Create a resource.",
tags: [OpenAPITags.Org, OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: createResourceParamsSchema,
body: {
@@ -223,6 +223,20 @@ async function createHttpResource(
);
}
// Prevent creating resource with same domain as dashboard
const dashboardUrl = config.getRawConfig().app.dashboard_url;
if (dashboardUrl) {
const dashboardHost = new URL(dashboardUrl).hostname;
if (fullDomain === dashboardHost) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource domain cannot be the same as the dashboard domain"
)
);
}
}
if (build != "oss") {
const existingLoginPages = await db
.select()

View File

@@ -31,7 +31,7 @@ registry.registerPath({
method: "put",
path: "/resource/{resourceId}/rule",
description: "Create a resource rule.",
tags: [OpenAPITags.Resource, OpenAPITags.Rule],
tags: [OpenAPITags.PublicResource, OpenAPITags.Rule],
request: {
params: createResourceRuleParamsSchema,
body: {

View File

@@ -22,7 +22,7 @@ registry.registerPath({
method: "delete",
path: "/resource/{resourceId}",
description: "Delete a resource.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: deleteResourceSchema
},

View File

@@ -19,7 +19,7 @@ registry.registerPath({
method: "delete",
path: "/resource/{resourceId}/rule/{ruleId}",
description: "Delete a resource rule.",
tags: [OpenAPITags.Resource, OpenAPITags.Rule],
tags: [OpenAPITags.PublicResource, OpenAPITags.Rule],
request: {
params: deleteResourceRuleSchema
},

View File

@@ -54,7 +54,7 @@ registry.registerPath({
path: "/org/{orgId}/resource/{niceId}",
description:
"Get a resource by orgId and niceId. NiceId is a readable ID for the resource and unique on a per org basis.",
tags: [OpenAPITags.Org, OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: z.object({
orgId: z.string(),
@@ -68,7 +68,7 @@ registry.registerPath({
method: "get",
path: "/resource/{resourceId}",
description: "Get a resource by resourceId.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: z.object({
resourceId: z.number()

View File

@@ -31,7 +31,7 @@ registry.registerPath({
method: "get",
path: "/resource/{resourceId}/whitelist",
description: "Get the whitelist of emails for a specific resource.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: getResourceWhitelistSchema
},

View File

@@ -33,7 +33,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/resources-names",
description: "List all resource names for an organization.",
tags: [OpenAPITags.Org, OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: z.object({
orgId: z.string()

View File

@@ -35,7 +35,7 @@ registry.registerPath({
method: "get",
path: "/resource/{resourceId}/roles",
description: "List all roles for a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
tags: [OpenAPITags.PublicResource, OpenAPITags.Role],
request: {
params: listResourceRolesSchema
},

View File

@@ -56,7 +56,7 @@ registry.registerPath({
method: "get",
path: "/resource/{resourceId}/rules",
description: "List rules for a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Rule],
tags: [OpenAPITags.PublicResource, OpenAPITags.Rule],
request: {
params: listResourceRulesParamsSchema,
query: listResourceRulesSchema

View File

@@ -38,7 +38,7 @@ registry.registerPath({
method: "get",
path: "/resource/{resourceId}/users",
description: "List all users for a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.User],
tags: [OpenAPITags.PublicResource, OpenAPITags.User],
request: {
params: listResourceUsersSchema
},

View File

@@ -225,7 +225,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/resources",
description: "List resources for an organization.",
tags: [OpenAPITags.Org, OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: z.object({
orgId: z.string()

View File

@@ -29,7 +29,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/whitelist/remove",
description: "Remove a single email from the resource whitelist.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: removeEmailFromResourceWhitelistParamsSchema,
body: {

View File

@@ -29,7 +29,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/roles/remove",
description: "Remove a single role from a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
tags: [OpenAPITags.PublicResource, OpenAPITags.Role],
request: {
params: removeRoleFromResourceParamsSchema,
body: {

View File

@@ -29,7 +29,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/users/remove",
description: "Remove a single user from a resource.",
tags: [OpenAPITags.Resource, OpenAPITags.User],
tags: [OpenAPITags.PublicResource, OpenAPITags.User],
request: {
params: removeUserFromResourceParamsSchema,
body: {

View File

@@ -29,7 +29,7 @@ registry.registerPath({
path: "/resource/{resourceId}/header-auth",
description:
"Set or update the header authentication for a resource. If user and password is not provided, it will remove the header authentication.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: setResourceAuthMethodsParamsSchema,
body: {

View File

@@ -25,7 +25,7 @@ registry.registerPath({
path: "/resource/{resourceId}/password",
description:
"Set the password for a resource. Setting the password to null will remove it.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: setResourceAuthMethodsParamsSchema,
body: {

View File

@@ -29,7 +29,7 @@ registry.registerPath({
path: "/resource/{resourceId}/pincode",
description:
"Set the PIN code for a resource. Setting the PIN code to null will remove it.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: setResourceAuthMethodsParamsSchema,
body: {

View File

@@ -23,7 +23,7 @@ registry.registerPath({
path: "/resource/{resourceId}/roles",
description:
"Set roles for a resource. This will replace all existing roles.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
tags: [OpenAPITags.PublicResource, OpenAPITags.Role],
request: {
params: setResourceRolesParamsSchema,
body: {

View File

@@ -23,7 +23,7 @@ registry.registerPath({
path: "/resource/{resourceId}/users",
description:
"Set users for a resource. This will replace all existing users.",
tags: [OpenAPITags.Resource, OpenAPITags.User],
tags: [OpenAPITags.PublicResource, OpenAPITags.User],
request: {
params: setUserResourcesParamsSchema,
body: {

View File

@@ -32,7 +32,7 @@ registry.registerPath({
path: "/resource/{resourceId}/whitelist",
description:
"Set email whitelist for a resource. This will replace all existing emails.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: setResourceWhitelistParamsSchema,
body: {

View File

@@ -101,6 +101,49 @@ const updateHttpResourceBodySchema = z
{
error: "Invalid custom Host Header value. Use domain name format, or save empty to unset custom Host Header."
}
)
.refine(
(data) => {
if (data.headers) {
// HTTP header names must be valid token characters (RFC 7230)
const validHeaderName = /^[a-zA-Z0-9!#$%&'*+\-.^_`|~]+$/;
return data.headers.every((h) => validHeaderName.test(h.name));
}
return true;
},
{
error: "Header names may only contain valid HTTP token characters (letters, digits, and !#$%&'*+-.^_`|~)."
}
)
.refine(
(data) => {
if (data.headers) {
// HTTP header values must be visible ASCII or horizontal whitespace, no control chars (RFC 7230)
const validHeaderValue = /^[\t\x20-\x7E]*$/;
return data.headers.every((h) => validHeaderValue.test(h.value));
}
return true;
},
{
error: "Header values may only contain printable ASCII characters and horizontal whitespace."
}
)
.refine(
(data) => {
if (data.headers) {
// Reject Traefik template syntax {{word}} in names or values
const templatePattern = /\{\{[^}]+\}\}/;
return data.headers.every(
(h) =>
!templatePattern.test(h.name) &&
!templatePattern.test(h.value)
);
}
return true;
},
{
error: "Header names and values must not contain template expressions such as {{value}}."
}
);
export type UpdateResourceResponse = Resource;
@@ -136,7 +179,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}",
description: "Update a resource.",
tags: [OpenAPITags.Resource],
tags: [OpenAPITags.PublicResource],
request: {
params: updateResourceParamsSchema,
body: {
@@ -310,6 +353,20 @@ async function updateHttpResource(
);
}
// Prevent updating resource with same domain as dashboard
const dashboardUrl = config.getRawConfig().app.dashboard_url;
if (dashboardUrl) {
const dashboardHost = new URL(dashboardUrl).hostname;
if (fullDomain === dashboardHost) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Resource domain cannot be the same as the dashboard domain"
)
);
}
}
if (build != "oss") {
const existingLoginPages = await db
.select()
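
To make the new header refinements above concrete, here are illustrative inputs checked against the same regular expressions; the example values are not from the diff:

// The three patterns are copied from the refinements above.
const validHeaderName = /^[a-zA-Z0-9!#$%&'*+\-.^_`|~]+$/;
const validHeaderValue = /^[\t\x20-\x7E]*$/;
const templatePattern = /\{\{[^}]+\}\}/;

validHeaderName.test("X-Forwarded-Host");  // true  - token characters only
validHeaderName.test("X Forwarded Host");  // false - spaces are not token characters
validHeaderValue.test("internal.example"); // true  - printable ASCII
validHeaderValue.test("bad\u0000value");   // false - control characters are rejected
templatePattern.test("{{ .Host }}");       // true  - so this value would be rejected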

View File

@@ -38,7 +38,7 @@ registry.registerPath({
method: "post",
path: "/resource/{resourceId}/rule/{ruleId}",
description: "Update a resource rule.",
tags: [OpenAPITags.Resource, OpenAPITags.Rule],
tags: [OpenAPITags.PublicResource, OpenAPITags.Rule],
request: {
params: updateResourceRuleParamsSchema,
body: {

View File

@@ -45,7 +45,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/role",
description: "Create a role.",
tags: [OpenAPITags.Org, OpenAPITags.Role],
tags: [OpenAPITags.Role],
request: {
params: createRoleParamsSchema,
body: {

View File

@@ -7,7 +7,7 @@ import { and, eq, inArray, sql } from "drizzle-orm";
import { ActionsEnum } from "@server/auth/actions";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { object, z } from "zod";
import { fromError } from "zod-validation-error";
const listRolesParamsSchema = z.strictObject({
@@ -64,7 +64,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/roles",
description: "List roles.",
tags: [OpenAPITags.Org, OpenAPITags.Role],
tags: [OpenAPITags.Role],
request: {
params: listRolesParamsSchema,
query: listRolesSchema

View File

@@ -58,7 +58,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/site",
description: "Create a new site.",
tags: [OpenAPITags.Site, OpenAPITags.Org],
tags: [OpenAPITags.Site],
request: {
params: createSiteParamsSchema,
body: {

View File

@@ -51,7 +51,7 @@ registry.registerPath({
path: "/org/{orgId}/site/{niceId}",
description:
"Get a site by orgId and niceId. NiceId is a readable ID for the site and unique on a per org basis.",
tags: [OpenAPITags.Org, OpenAPITags.Site],
tags: [OpenAPITags.Site],
request: {
params: z.object({
orgId: z.string(),

View File

@@ -55,7 +55,7 @@ async function getLatestNewtVersion(): Promise<string | null> {
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
await cache.set("latestNewtVersion", latestVersion);
await cache.set("latestNewtVersion", latestVersion, 3600);
return latestVersion;
} catch (error: any) {

View File

@@ -35,7 +35,7 @@ registry.registerPath({
path: "/org/{orgId}/pick-site-defaults",
description:
"Return pre-requisite data for creating a site, such as the exit node, subnet, Newt credentials, etc.",
tags: [OpenAPITags.Org, OpenAPITags.Site],
tags: [OpenAPITags.Site],
request: {
params: z.object({
orgId: z.string()

View File

@@ -30,7 +30,7 @@ registry.registerPath({
path: "/site-resource/{siteResourceId}/clients/add",
description:
"Add a single client to a site resource. Clients with a userId cannot be added.",
tags: [OpenAPITags.Resource, OpenAPITags.Client],
tags: [OpenAPITags.PrivateResource, OpenAPITags.Client],
request: {
params: addClientToSiteResourceParamsSchema,
body: {

View File

@@ -30,7 +30,7 @@ registry.registerPath({
method: "post",
path: "/site-resource/{siteResourceId}/roles/add",
description: "Add a single role to a site resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
tags: [OpenAPITags.PrivateResource, OpenAPITags.Role],
request: {
params: addRoleToSiteResourceParamsSchema,
body: {

View File

@@ -30,7 +30,7 @@ registry.registerPath({
method: "post",
path: "/site-resource/{siteResourceId}/users/add",
description: "Add a single user to a site resource.",
tags: [OpenAPITags.Resource, OpenAPITags.User],
tags: [OpenAPITags.PrivateResource, OpenAPITags.User],
request: {
params: addUserToSiteResourceParamsSchema,
body: {

View File

@@ -0,0 +1,247 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import {
db,
clients,
clientSiteResources,
siteResources,
apiKeyOrg
} from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { eq, and, inArray } from "drizzle-orm";
import { OpenAPITags, registry } from "@server/openApi";
import {
rebuildClientAssociationsFromClient,
rebuildClientAssociationsFromSiteResource
} from "@server/lib/rebuildClientAssociations";
const batchAddClientToSiteResourcesParamsSchema = z
.object({
clientId: z.string().transform(Number).pipe(z.number().int().positive())
})
.strict();
const batchAddClientToSiteResourcesBodySchema = z
.object({
siteResourceIds: z
.array(z.number().int().positive())
.min(1, "At least one siteResourceId is required")
})
.strict();
registry.registerPath({
method: "post",
path: "/client/{clientId}/site-resources",
description: "Add a machine client to multiple site resources at once.",
tags: [OpenAPITags.Client],
request: {
params: batchAddClientToSiteResourcesParamsSchema,
body: {
content: {
"application/json": {
schema: batchAddClientToSiteResourcesBodySchema
}
}
}
},
responses: {}
});
export async function batchAddClientToSiteResources(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const apiKey = req.apiKey;
if (!apiKey) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Key not authenticated")
);
}
const parsedParams =
batchAddClientToSiteResourcesParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = batchAddClientToSiteResourcesBodySchema.safeParse(
req.body
);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { clientId } = parsedParams.data;
const { siteResourceIds } = parsedBody.data;
const uniqueSiteResourceIds = [...new Set(siteResourceIds)];
const batchSiteResources = await db
.select()
.from(siteResources)
.where(
inArray(siteResources.siteResourceId, uniqueSiteResourceIds)
);
if (batchSiteResources.length !== uniqueSiteResourceIds.length) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"One or more site resources not found"
)
);
}
if (!apiKey.isRoot) {
const orgIds = [
...new Set(batchSiteResources.map((sr) => sr.orgId))
];
if (orgIds.length > 1) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"All site resources must belong to the same organization"
)
);
}
const orgId = orgIds[0];
const [apiKeyOrgRow] = await db
.select()
.from(apiKeyOrg)
.where(
and(
eq(apiKeyOrg.apiKeyId, apiKey.apiKeyId),
eq(apiKeyOrg.orgId, orgId)
)
)
.limit(1);
if (!apiKeyOrgRow) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to the organization of the specified site resources"
)
);
}
const [clientInOrg] = await db
.select()
.from(clients)
.where(
and(
eq(clients.clientId, clientId),
eq(clients.orgId, orgId)
)
)
.limit(1);
if (!clientInOrg) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Key does not have access to the specified client"
)
);
}
}
const [client] = await db
.select()
.from(clients)
.where(eq(clients.clientId, clientId))
.limit(1);
if (!client) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Client not found")
);
}
if (client.userId !== null) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"This endpoint only supports machine (non-user) clients; the specified client is associated with a user"
)
);
}
const existingEntries = await db
.select({
siteResourceId: clientSiteResources.siteResourceId
})
.from(clientSiteResources)
.where(
and(
eq(clientSiteResources.clientId, clientId),
inArray(
clientSiteResources.siteResourceId,
batchSiteResources.map((sr) => sr.siteResourceId)
)
)
);
const existingSiteResourceIds = new Set(
existingEntries.map((e) => e.siteResourceId)
);
const siteResourcesToAdd = batchSiteResources.filter(
(sr) => !existingSiteResourceIds.has(sr.siteResourceId)
);
if (siteResourcesToAdd.length === 0) {
return next(
createHttpError(
HttpCode.CONFLICT,
"Client is already assigned to all specified site resources"
)
);
}
await db.transaction(async (trx) => {
for (const siteResource of siteResourcesToAdd) {
await trx.insert(clientSiteResources).values({
clientId,
siteResourceId: siteResource.siteResourceId
});
}
await rebuildClientAssociationsFromClient(client, trx);
});
return response(res, {
data: {
addedCount: siteResourcesToAdd.length,
skippedCount:
batchSiteResources.length - siteResourcesToAdd.length,
siteResourceIds: siteResourcesToAdd.map(
(sr) => sr.siteResourceId
)
},
success: true,
error: false,
message: `Client added to ${siteResourcesToAdd.length} site resource(s) successfully`,
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
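
A hedged example of calling the new batch route; the host, version prefix, and authentication header are assumptions:

// Hypothetical request adding machine client 42 to three site resources in one call.
await fetch("https://pangolin.example.com/v1/client/42/site-resources", {
    method: "POST",
    headers: {
        Authorization: "Bearer <api-key>",
        "Content-Type": "application/json"
    },
    body: JSON.stringify({ siteResourceIds: [101, 102, 103] })
});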

View File

@@ -114,7 +114,7 @@ registry.registerPath({
method: "put",
path: "/org/{orgId}/site-resource",
description: "Create a new site resource.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.PrivateResource],
request: {
params: createSiteResourceParamsSchema,
body: {

View File

@@ -23,7 +23,7 @@ registry.registerPath({
method: "delete",
path: "/site-resource/{siteResourceId}",
description: "Delete a site resource.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.PrivateResource],
request: {
params: deleteSiteResourceParamsSchema
},

View File

@@ -65,7 +65,7 @@ registry.registerPath({
method: "get",
path: "/site-resource/{siteResourceId}",
description: "Get a specific site resource by siteResourceId.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.PrivateResource],
request: {
params: z.object({
siteResourceId: z.number(),
@@ -80,7 +80,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/site/{siteId}/resource/nice/{niceId}",
description: "Get a specific site resource by niceId.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.PrivateResource],
request: {
params: z.object({
niceId: z.string(),

View File

@@ -15,4 +15,5 @@ export * from "./addUserToSiteResource";
export * from "./removeUserFromSiteResource";
export * from "./setSiteResourceClients";
export * from "./addClientToSiteResource";
export * from "./batchAddClientToSiteResources";
export * from "./removeClientFromSiteResource";

View File

@@ -112,7 +112,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/site-resources",
description: "List all site resources for an organization.",
tags: [OpenAPITags.Client, OpenAPITags.Org],
tags: [OpenAPITags.PrivateResource],
request: {
params: listAllSiteResourcesByOrgParamsSchema,
query: listAllSiteResourcesByOrgQuerySchema

View File

@@ -39,7 +39,7 @@ registry.registerPath({
method: "get",
path: "/site-resource/{siteResourceId}/clients",
description: "List all clients for a site resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Client],
tags: [OpenAPITags.PrivateResource, OpenAPITags.Client],
request: {
params: listSiteResourceClientsSchema
},

View File

@@ -40,7 +40,7 @@ registry.registerPath({
method: "get",
path: "/site-resource/{siteResourceId}/roles",
description: "List all roles for a site resource.",
tags: [OpenAPITags.Resource, OpenAPITags.Role],
tags: [OpenAPITags.PrivateResource, OpenAPITags.Role],
request: {
params: listSiteResourceRolesSchema
},

Some files were not shown because too many files have changed in this diff.