Merge branch 'dev' into private-site-ha

This commit is contained in:
Owen
2026-04-09 17:39:45 -04:00
282 changed files with 22523 additions and 4747 deletions

View File

@@ -208,7 +208,7 @@ export async function listAccessTokens(
.where(
or(
eq(userResources.userId, req.user!.userId),
eq(roleResources.roleId, req.userOrgRoleId!)
inArray(roleResources.roleId, req.userOrgRoleIds!)
)
);
} else {

View File

@@ -91,3 +91,50 @@ export type QueryAccessAuditLogResponse = {
locations: string[];
};
};
export type QueryConnectionAuditLogResponse = {
log: {
sessionId: string;
siteResourceId: number | null;
orgId: string | null;
siteId: number | null;
clientId: number | null;
userId: string | null;
sourceAddr: string;
destAddr: string;
protocol: string;
startedAt: number;
endedAt: number | null;
bytesTx: number | null;
bytesRx: number | null;
resourceName: string | null;
resourceNiceId: string | null;
siteName: string | null;
siteNiceId: string | null;
clientName: string | null;
clientNiceId: string | null;
clientType: string | null;
userEmail: string | null;
}[];
pagination: {
total: number;
limit: number;
offset: number;
};
filterAttributes: {
protocols: string[];
destAddrs: string[];
clients: {
id: number;
name: string;
}[];
resources: {
id: number;
name: string | null;
}[];
users: {
id: string;
email: string | null;
}[];
};
};

View File

@@ -5,6 +5,8 @@ import cache from "#dynamic/lib/cache";
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
import { stripPortFromHost } from "@server/lib/ip";
import { sanitizeString } from "@server/lib/sanitize";
/**
Reasons:
@@ -253,24 +255,23 @@ export async function logRequestAudit(
// Add to buffer instead of writing directly to DB
auditLogBuffer.push({
timestamp,
orgId: data.orgId,
actorType,
actor,
actorId,
metadata,
orgId: sanitizeString(data.orgId),
actorType: sanitizeString(actorType),
actor: sanitizeString(actor),
actorId: sanitizeString(actorId),
metadata: sanitizeString(metadata),
action: data.action,
resourceId: data.resourceId,
reason: data.reason,
location: data.location,
originalRequestURL: body.originalRequestURL,
scheme: body.scheme,
host: body.host,
path: body.path,
method: body.method,
ip: clientIp,
location: sanitizeString(data.location),
originalRequestURL: sanitizeString(body.originalRequestURL) ?? "",
scheme: sanitizeString(body.scheme) ?? "",
host: sanitizeString(body.host) ?? "",
path: sanitizeString(body.path) ?? "",
method: sanitizeString(body.method) ?? "",
ip: sanitizeString(clientIp),
tls: body.tls
});
// Flush immediately if buffer is full, otherwise schedule a flush
if (auditLogBuffer.length >= BATCH_SIZE) {
// Fire and forget - don't block the caller

View File

@@ -1,4 +1,37 @@
import { assertEquals } from "@test/assert";
import { REGIONS } from "@server/db/regions";
function isIpInRegion(
ipCountryCode: string | undefined,
checkRegionCode: string
): boolean {
if (!ipCountryCode) {
return false;
}
const upperCode = ipCountryCode.toUpperCase();
for (const region of REGIONS) {
// Check if it's a top-level region (continent)
if (region.id === checkRegionCode) {
for (const subregion of region.includes) {
if (subregion.countries.includes(upperCode)) {
return true;
}
}
return false;
}
// Check subregions
for (const subregion of region.includes) {
if (subregion.id === checkRegionCode) {
return subregion.countries.includes(upperCode);
}
}
}
return false;
}
function isPathAllowed(pattern: string, path: string): boolean {
// Normalize and split paths into segments
@@ -272,12 +305,71 @@ function runTests() {
"Root path should not match non-root path"
);
console.log("All tests passed!");
console.log("All path matching tests passed!");
}
function runRegionTests() {
console.log("\nRunning isIpInRegion tests...");
// Test undefined country code
assertEquals(
isIpInRegion(undefined, "150"),
false,
"Undefined country code should return false"
);
// Test subregion matching (Western Europe)
assertEquals(
isIpInRegion("DE", "155"),
true,
"Country should match its subregion"
);
assertEquals(
isIpInRegion("GB", "155"),
false,
"Country should NOT match wrong subregion"
);
// Test continent matching (Europe)
assertEquals(
isIpInRegion("DE", "150"),
true,
"Country should match its continent"
);
assertEquals(
isIpInRegion("GB", "150"),
true,
"Different European country should match Europe"
);
assertEquals(
isIpInRegion("US", "150"),
false,
"Non-European country should NOT match Europe"
);
// Test case insensitivity
assertEquals(
isIpInRegion("de", "155"),
true,
"Lowercase country code should work"
);
// Test invalid region code
assertEquals(
isIpInRegion("DE", "999"),
false,
"Invalid region code should return false"
);
console.log("All region tests passed!");
}
// Run all tests
try {
runTests();
runRegionTests();
console.log("\n✅ All tests passed!");
} catch (error) {
console.error("Test failed:", error);
console.error("Test failed:", error);
process.exit(1);
}

View File

@@ -4,11 +4,11 @@ import {
getResourceByDomain,
getResourceRules,
getRoleResourceAccess,
getUserOrgRole,
getUserResourceAccess,
getOrgLoginPage,
getUserSessionWithUser
} from "@server/db/queries/verifySessionQueries";
import { getUserOrgRoles } from "@server/lib/userOrgRoles";
import {
LoginPage,
Org,
@@ -30,13 +30,13 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
import { getCountryCodeForIp } from "@server/lib/geoip";
import { getAsnForIp } from "@server/lib/asn";
import { getOrgTierData } from "#dynamic/lib/billing";
import { verifyPassword } from "@server/auth/password";
import {
checkOrgAccessPolicy,
enforceResourceSessionLength
} from "#dynamic/lib/checkOrgAccessPolicy";
import { logRequestAudit } from "./logRequestAudit";
import { REGIONS } from "@server/db/regions";
import { localCache } from "#dynamic/lib/cache";
import { APP_VERSION } from "@server/lib/consts";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
@@ -797,7 +797,8 @@ async function notAllowed(
) {
let loginPage: LoginPage | null = null;
if (orgId) {
const subscribed = await isSubscribed( // this is fine because the org login page is only a saas feature
const subscribed = await isSubscribed(
// this is fine because the org login page is only a saas feature
orgId,
tierMatrix.loginPageDomain
);
@@ -854,7 +855,10 @@ async function headerAuthChallenged(
) {
let loginPage: LoginPage | null = null;
if (orgId) {
const subscribed = await isSubscribed(orgId, tierMatrix.loginPageDomain); // this is fine because the org login page is only a saas feature
const subscribed = await isSubscribed(
orgId,
tierMatrix.loginPageDomain
); // this is fine because the org login page is only a saas feature
if (subscribed) {
loginPage = await getOrgLoginPage(orgId);
}
@@ -916,9 +920,9 @@ async function isUserAllowedToAccessResource(
return null;
}
const userOrgRole = await getUserOrgRole(user.userId, resource.orgId);
const userOrgRoles = await getUserOrgRoles(user.userId, resource.orgId);
if (!userOrgRole) {
if (!userOrgRoles.length) {
return null;
}
@@ -936,15 +940,14 @@ async function isUserAllowedToAccessResource(
const roleResourceAccess = await getRoleResourceAccess(
resource.resourceId,
userOrgRole.roleId
userOrgRoles.map((r) => r.roleId)
);
if (roleResourceAccess) {
if (roleResourceAccess && roleResourceAccess.length > 0) {
return {
username: user.username,
email: user.email,
name: user.name,
role: userOrgRole.roleName
role: userOrgRoles.map((r) => r.roleName).join(", ")
};
}
@@ -958,7 +961,7 @@ async function isUserAllowedToAccessResource(
username: user.username,
email: user.email,
name: user.name,
role: userOrgRole.roleName
role: userOrgRoles.map((r) => r.roleName).join(", ")
};
}
@@ -1020,6 +1023,12 @@ async function checkRules(
(await isIpInAsn(ipAsn, rule.value))
) {
return rule.action as any;
} else if (
clientIp &&
rule.match == "REGION" &&
(await isIpInRegion(ipCC, rule.value))
) {
return rule.action as any;
}
}
@@ -1205,6 +1214,45 @@ async function isIpInAsn(
return match;
}
export async function isIpInRegion(
ipCountryCode: string | undefined,
checkRegionCode: string
): Promise<boolean> {
if (!ipCountryCode) {
return false;
}
const upperCode = ipCountryCode.toUpperCase();
for (const region of REGIONS) {
// Check if it's a top-level region (continent)
if (region.id === checkRegionCode) {
for (const subregion of region.includes) {
if (subregion.countries.includes(upperCode)) {
logger.debug(`Country ${upperCode} is in region ${region.id} (${region.name})`);
return true;
}
}
logger.debug(`Country ${upperCode} is not in region ${region.id} (${region.name})`);
return false;
}
// Check subregions
for (const subregion of region.includes) {
if (subregion.id === checkRegionCode) {
if (subregion.countries.includes(upperCode)) {
logger.debug(`Country ${upperCode} is in region ${subregion.id} (${subregion.name})`);
return true;
}
logger.debug(`Country ${upperCode} is not in region ${subregion.id} (${subregion.name})`);
return false;
}
}
}
return false;
}
async function getAsnFromIp(ip: string): Promise<number | undefined> {
const asnCacheKey = `asn:${ip}`;

View File

@@ -92,7 +92,7 @@ export async function createClient(
const { orgId } = parsedParams.data;
if (req.user && !req.userOrgRoleId) {
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
@@ -234,7 +234,7 @@ export async function createClient(
clientId: newClient.clientId
});
if (req.user && req.userOrgRoleId != adminRole.roleId) {
if (req.user && !req.userOrgRoleIds?.includes(adminRole.roleId)) {
// make sure the user can access the client
trx.insert(userClients).values({
userId: req.user.userId,

View File

@@ -70,7 +70,7 @@ async function getLatestOlmVersion(): Promise<string | null> {
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
olmVersionCache.set("latestOlmVersion", latestVersion);
olmVersionCache.set("latestOlmVersion", latestVersion, 3600);
return latestVersion;
} catch (error: any) {
@@ -297,7 +297,7 @@ export async function listClients(
.where(
or(
eq(userClients.userId, req.user!.userId),
eq(roleClients.roleId, req.userOrgRoleId!)
inArray(roleClients.roleId, req.userOrgRoleIds!)
)
);
} else {

View File

@@ -71,7 +71,7 @@ async function getLatestOlmVersion(): Promise<string | null> {
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
olmVersionCache.set("latestOlmVersion", latestVersion);
olmVersionCache.set("latestOlmVersion", latestVersion, 3600);
return latestVersion;
} catch (error: any) {
@@ -316,7 +316,7 @@ export async function listUserDevices(
.where(
or(
eq(userClients.userId, req.user!.userId),
eq(roleClients.roleId, req.userOrgRoleId!)
inArray(roleClients.roleId, req.userOrgRoleIds!)
)
);
} else {

View File

@@ -1,15 +1,54 @@
import { sendToClient } from "#dynamic/routers/ws";
import { db, olms, Transaction } from "@server/db";
import { db, newts, olms } from "@server/db";
import {
Alias,
convertSubnetProxyTargetsV2ToV1,
SubnetProxyTarget,
SubnetProxyTargetV2
} from "@server/lib/ip";
import { canCompress } from "@server/lib/clientVersionChecks";
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
import logger from "@server/logger";
import { eq } from "drizzle-orm";
import semver from "semver";
const NEWT_V2_TARGETS_VERSION = ">=1.10.3";
export async function convertTargetsIfNessicary(
newtId: string,
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]
) {
// get the newt
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!newt) {
throw new Error(`No newt found for id: ${newtId}`);
}
// check the semver
if (
newt.version &&
!semver.satisfies(newt.version, NEWT_V2_TARGETS_VERSION)
) {
logger.debug(
`addTargets Newt version ${newt.version} does not support targets v2 falling back`
);
targets = convertSubnetProxyTargetsV2ToV1(
targets as SubnetProxyTargetV2[]
);
}
return targets;
}
export async function addTargets(
newtId: string,
targets: SubnetProxyTarget[],
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
version?: string | null
) {
targets = await convertTargetsIfNessicary(newtId, targets);
await sendToClient(
newtId,
{
@@ -22,9 +61,11 @@ export async function addTargets(
export async function removeTargets(
newtId: string,
targets: SubnetProxyTarget[],
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
version?: string | null
) {
targets = await convertTargetsIfNessicary(newtId, targets);
await sendToClient(
newtId,
{
@@ -38,11 +79,39 @@ export async function removeTargets(
export async function updateTargets(
newtId: string,
targets: {
oldTargets: SubnetProxyTarget[];
newTargets: SubnetProxyTarget[];
oldTargets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
newTargets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
},
version?: string | null
) {
// get the newt
const [newt] = await db
.select()
.from(newts)
.where(eq(newts.newtId, newtId));
if (!newt) {
logger.error(`addTargetsL No newt found for id: ${newtId}`);
return;
}
// check the semver
if (
newt.version &&
!semver.satisfies(newt.version, NEWT_V2_TARGETS_VERSION)
) {
logger.debug(
`addTargets Newt version ${newt.version} does not support targets v2 falling back`
);
targets = {
oldTargets: convertSubnetProxyTargetsV2ToV1(
targets.oldTargets as SubnetProxyTargetV2[]
),
newTargets: convertSubnetProxyTargetsV2ToV1(
targets.newTargets as SubnetProxyTargetV2[]
)
};
}
await sendToClient(
newtId,
{

View File

@@ -102,6 +102,8 @@ authenticated.put(
logActionAudit(ActionsEnum.createSite),
site.createSite
);
authenticated.get(
"/org/:orgId/sites",
verifyOrgAccess,
@@ -644,6 +646,7 @@ authenticated.delete(
logActionAudit(ActionsEnum.deleteRole),
role.deleteRole
);
authenticated.post(
"/role/:roleId/add/:userId",
verifyRoleAccess,
@@ -651,7 +654,7 @@ authenticated.post(
verifyLimits,
verifyUserHasAction(ActionsEnum.addUserRole),
logActionAudit(ActionsEnum.addUserRole),
user.addUserRole
user.addUserRoleLegacy
);
authenticated.post(
@@ -793,6 +796,11 @@ unauthenticated.get(
// );
unauthenticated.get("/user", verifySessionMiddleware, user.getUser);
unauthenticated.post(
"/user/locale",
verifySessionMiddleware,
user.updateUserLocale
);
unauthenticated.get("/my-device", verifySessionMiddleware, user.myDevice);
authenticated.get("/users", verifyUserIsServerAdmin, user.adminListUsers);
@@ -1202,6 +1210,22 @@ authRouter.post(
}),
newt.getNewtToken
);
authRouter.post(
"/newt/register",
rateLimit({
windowMs: 15 * 60 * 1000,
max: 30,
keyGenerator: (req) =>
`newtRegister:${req.body.provisioningKey?.split(".")[0] || ipKeyGenerator(req.ip || "")}`,
handler: (req, res, next) => {
const message = `You can only register a newt ${30} times every ${15} minutes. Please try again later.`;
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
},
store: createStore()
}),
newt.registerNewt
);
authRouter.post(
"/olm/get-token",
rateLimit({

View File

@@ -1,7 +1,6 @@
import { Request, Response, NextFunction } from "express";
import { eq, sql } from "drizzle-orm";
import { sites } from "@server/db";
import { db } from "@server/db";
import { sql } from "drizzle-orm";
import { db, DB_TYPE } from "@server/db";
import logger from "@server/logger";
import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
@@ -31,7 +30,10 @@ const MAX_RETRIES = 3;
const BASE_DELAY_MS = 50;
// How often to flush accumulated bandwidth data to the database
const FLUSH_INTERVAL_MS = 30_000; // 30 seconds
const FLUSH_INTERVAL_MS = 300_000; // 300 seconds
// Maximum number of sites to include in a single batch UPDATE statement
const BATCH_CHUNK_SIZE = 250;
// In-memory accumulator: publicKey -> AccumulatorEntry
let accumulator = new Map<string, AccumulatorEntry>();
@@ -75,13 +77,37 @@ async function withDeadlockRetry<T>(
}
}
/**
* Execute a raw SQL query that returns rows, in a way that works across both
* the PostgreSQL driver (which exposes `execute`) and the SQLite driver (which
* exposes `all`). Drizzle's typed query builder doesn't support bulk
* UPDATE … FROM (VALUES …) natively, so we drop to raw SQL here.
*/
async function dbQueryRows<T extends Record<string, unknown>>(
query: Parameters<(typeof sql)["join"]>[0][number]
): Promise<T[]> {
const anyDb = db as any;
if (typeof anyDb.execute === "function") {
// PostgreSQL (node-postgres via Drizzle) — returns { rows: [...] } or an array
const result = await anyDb.execute(query);
return (Array.isArray(result) ? result : (result.rows ?? [])) as T[];
}
// SQLite (better-sqlite3 via Drizzle) — returns an array directly
return (await anyDb.all(query)) as T[];
}
function isSQLite(): boolean {
return DB_TYPE == "sqlite";
}
/**
* Flush all accumulated site bandwidth data to the database.
*
* Swaps out the accumulator before writing so that any bandwidth messages
* received during the flush are captured in the new accumulator rather than
* being lost or causing contention. Entries that fail to write are re-queued
* back into the accumulator so they will be retried on the next flush.
* being lost or causing contention. Sites are updated in chunks via a single
* batch UPDATE per chunk. Failed chunks are discarded — exact per-flush
* accuracy is not critical and re-queuing is not worth the added complexity.
*
* This function is exported so that the application's graceful-shutdown
* cleanup handler can call it before the process exits.
@@ -108,76 +134,93 @@ export async function flushSiteBandwidthToDb(): Promise<void> {
`Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
);
// Aggregate billing usage by org, collected during the DB update loop.
// Build a lookup so post-processing can reach each entry by publicKey.
const snapshotMap = new Map(sortedEntries);
// Aggregate billing usage by org across all chunks.
const orgUsageMap = new Map<string, number>();
for (const [publicKey, { bytesIn, bytesOut, exitNodeId, calcUsage }] of sortedEntries) {
// Process in chunks so individual queries stay at a reasonable size.
for (let i = 0; i < sortedEntries.length; i += BATCH_CHUNK_SIZE) {
const chunk = sortedEntries.slice(i, i + BATCH_CHUNK_SIZE);
const chunkEnd = i + chunk.length - 1;
let rows: { orgId: string; pubKey: string }[] = [];
try {
const updatedSite = await withDeadlockRetry(async () => {
const [result] = await db
.update(sites)
.set({
megabytesOut: sql`COALESCE(${sites.megabytesOut}, 0) + ${bytesIn}`,
megabytesIn: sql`COALESCE(${sites.megabytesIn}, 0) + ${bytesOut}`,
lastBandwidthUpdate: currentTime,
})
.where(eq(sites.pubKey, publicKey))
.returning({
orgId: sites.orgId,
siteId: sites.siteId
});
return result;
}, `flush bandwidth for site ${publicKey}`);
if (updatedSite) {
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(
exitNodeId,
updatedSite.orgId
);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
);
// Skip usage tracking for this site but continue
// processing the rest.
continue;
rows = await withDeadlockRetry(async () => {
if (isSQLite()) {
// SQLite: one UPDATE per row — no need for batch efficiency here.
const results: { orgId: string; pubKey: string }[] = [];
for (const [publicKey, { bytesIn, bytesOut }] of chunk) {
const result = await dbQueryRows<{
orgId: string;
pubKey: string;
}>(sql`
UPDATE sites
SET
"bytesOut" = COALESCE("bytesOut", 0) + ${bytesIn},
"bytesIn" = COALESCE("bytesIn", 0) + ${bytesOut},
"lastBandwidthUpdate" = ${currentTime}
WHERE "pubKey" = ${publicKey}
RETURNING "orgId", "pubKey"
`);
results.push(...result);
}
return results;
}
if (calcUsage) {
const totalBandwidth = bytesIn + bytesOut;
const current = orgUsageMap.get(updatedSite.orgId) ?? 0;
orgUsageMap.set(updatedSite.orgId, current + totalBandwidth);
}
}
// PostgreSQL: batch UPDATE … FROM (VALUES …) — single round-trip per chunk.
const valuesList = chunk.map(([publicKey, { bytesIn, bytesOut }]) =>
sql`(${publicKey}::text, ${bytesIn}::real, ${bytesOut}::real)`
);
const valuesClause = sql.join(valuesList, sql`, `);
return dbQueryRows<{ orgId: string; pubKey: string }>(sql`
UPDATE sites
SET
"bytesOut" = COALESCE("bytesOut", 0) + v.bytes_in,
"bytesIn" = COALESCE("bytesIn", 0) + v.bytes_out,
"lastBandwidthUpdate" = ${currentTime}
FROM (VALUES ${valuesClause}) AS v(pub_key, bytes_in, bytes_out)
WHERE sites."pubKey" = v.pub_key
RETURNING sites."orgId" AS "orgId", sites."pubKey" AS "pubKey"
`);
}, `flush bandwidth chunk [${i}${chunkEnd}]`);
} catch (error) {
logger.error(
`Failed to flush bandwidth for site ${publicKey}:`,
`Failed to flush bandwidth chunk [${i}${chunkEnd}], discarding ${chunk.length} site(s):`,
error
);
// Discard the chunk — exact per-flush accuracy is not critical.
continue;
}
// Re-queue the failed entry so it is retried on the next flush
// rather than silently dropped.
const existing = accumulator.get(publicKey);
if (existing) {
existing.bytesIn += bytesIn;
existing.bytesOut += bytesOut;
} else {
accumulator.set(publicKey, {
bytesIn,
bytesOut,
exitNodeId,
calcUsage
});
// Collect billing usage from the returned rows.
for (const { orgId, pubKey } of rows) {
const entry = snapshotMap.get(pubKey);
if (!entry) continue;
const { bytesIn, bytesOut, exitNodeId, calcUsage } = entry;
if (exitNodeId) {
const notAllowed = await checkExitNodeOrg(exitNodeId, orgId);
if (notAllowed) {
logger.warn(
`Exit node ${exitNodeId} is not allowed for org ${orgId}`
);
continue;
}
}
if (calcUsage) {
const current = orgUsageMap.get(orgId) ?? 0;
orgUsageMap.set(orgId, current + bytesIn + bytesOut);
}
}
}
// Process billing usage updates outside the site-update loop to keep
// lock scope small and concerns separated.
// Process billing usage updates after all chunks are written.
if (orgUsageMap.size > 0) {
// Sort org IDs for consistent lock ordering.
const sortedOrgIds = [...orgUsageMap.keys()].sort();
for (const orgId of sortedOrgIds) {

View File

@@ -25,7 +25,8 @@ const bodySchema = z.strictObject({
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
tags: z.string().optional()
tags: z.string().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc")
});
export type CreateIdpResponse = {
@@ -77,7 +78,8 @@ export async function createOidcIdp(
namePath,
name,
autoProvision,
tags
tags,
variant
} = parsedBody.data;
if (
@@ -121,7 +123,8 @@ export async function createOidcIdp(
scopes,
identifierPath,
emailPath,
namePath
namePath,
variant
});
});

View File

@@ -31,7 +31,8 @@ const bodySchema = z.strictObject({
autoProvision: z.boolean().optional(),
defaultRoleMapping: z.string().optional(),
defaultOrgMapping: z.string().optional(),
tags: z.string().optional()
tags: z.string().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional()
});
export type UpdateIdpResponse = {
@@ -96,7 +97,8 @@ export async function updateOidcIdp(
autoProvision,
defaultRoleMapping,
defaultOrgMapping,
tags
tags,
variant
} = parsedBody.data;
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
@@ -159,7 +161,8 @@ export async function updateOidcIdp(
scopes,
identifierPath,
emailPath,
namePath
namePath,
variant
};
keysToUpdate = Object.keys(configData).filter(

View File

@@ -13,6 +13,7 @@ import {
orgs,
Role,
roles,
userOrgRoles,
userOrgs,
users
} from "@server/db";
@@ -35,11 +36,13 @@ import { usageService } from "@server/lib/billing/usageService";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import {
assignUserToOrg,
removeUserFromOrg
} from "@server/lib/userOrg";
import { unwrapRoleMapping } from "@app/lib/idpRoleMapping";
const ensureTrailingSlash = (url: string): string => {
return url;
@@ -366,7 +369,7 @@ export async function validateOidcCallback(
const defaultRoleMapping = existingIdp.idp.defaultRoleMapping;
const defaultOrgMapping = existingIdp.idp.defaultOrgMapping;
const userOrgInfo: { orgId: string; roleId: number }[] = [];
const userOrgInfo: { orgId: string; roleIds: number[] }[] = [];
for (const org of allOrgs) {
const [idpOrgRes] = await db
.select()
@@ -378,8 +381,6 @@ export async function validateOidcCallback(
)
);
let roleId: number | undefined = undefined;
const orgMapping = idpOrgRes?.orgMapping || defaultOrgMapping;
const hydratedOrgMapping = hydrateOrgMapping(
orgMapping,
@@ -404,38 +405,55 @@ export async function validateOidcCallback(
idpOrgRes?.roleMapping || defaultRoleMapping;
if (roleMapping) {
logger.debug("Role Mapping", { roleMapping });
const roleName = jmespath.search(claims, roleMapping);
const roleMappingJmes = unwrapRoleMapping(
roleMapping
).evaluationExpression;
const roleMappingResult = jmespath.search(
claims,
roleMappingJmes
);
const roleNames = normalizeRoleMappingResult(
roleMappingResult
);
if (!roleName) {
logger.error("Role name not found in the ID token", {
roleName
const supportsMultiRole = await isLicensedOrSubscribed(
org.orgId,
tierMatrix.fullRbac
);
const effectiveRoleNames = supportsMultiRole
? roleNames
: roleNames.slice(0, 1);
if (!effectiveRoleNames.length) {
logger.error("Role mapping returned no valid roles", {
roleMappingResult
});
continue;
}
const [roleRes] = await db
const roleRes = await db
.select()
.from(roles)
.where(
and(
eq(roles.orgId, org.orgId),
eq(roles.name, roleName)
inArray(roles.name, effectiveRoleNames)
)
);
if (!roleRes) {
logger.error("Role not found", {
if (!roleRes.length) {
logger.error("No mapped roles found in organization", {
orgId: org.orgId,
roleName
roleNames: effectiveRoleNames
});
continue;
}
roleId = roleRes.roleId;
const roleIds = [...new Set(roleRes.map((r) => r.roleId))];
userOrgInfo.push({
orgId: org.orgId,
roleId
roleIds
});
}
}
@@ -570,32 +588,28 @@ export async function validateOidcCallback(
}
}
// Update roles for existing auto-provisioned orgs where the role has changed
const orgsToUpdate = autoProvisionedOrgs.filter(
(currentOrg) => {
const newOrg = userOrgInfo.find(
(newOrg) => newOrg.orgId === currentOrg.orgId
);
return newOrg && newOrg.roleId !== currentOrg.roleId;
}
);
// Sync roles 1:1 with IdP policy for existing auto-provisioned orgs
for (const currentOrg of autoProvisionedOrgs) {
const newRole = userOrgInfo.find(
(newOrg) => newOrg.orgId === currentOrg.orgId
);
if (!newRole) continue;
if (orgsToUpdate.length > 0) {
for (const org of orgsToUpdate) {
const newRole = userOrgInfo.find(
(newOrg) => newOrg.orgId === org.orgId
await trx
.delete(userOrgRoles)
.where(
and(
eq(userOrgRoles.userId, userId!),
eq(userOrgRoles.orgId, currentOrg.orgId)
)
);
if (newRole) {
await trx
.update(userOrgs)
.set({ roleId: newRole.roleId })
.where(
and(
eq(userOrgs.userId, userId!),
eq(userOrgs.orgId, org.orgId)
)
);
}
for (const roleId of newRole.roleIds) {
await trx.insert(userOrgRoles).values({
userId: userId!,
orgId: currentOrg.orgId,
roleId
});
}
}
@@ -609,6 +623,10 @@ export async function validateOidcCallback(
if (orgsToAdd.length > 0) {
for (const org of orgsToAdd) {
if (org.roleIds.length === 0) {
continue;
}
const [fullOrg] = await trx
.select()
.from(orgs)
@@ -619,9 +637,9 @@ export async function validateOidcCallback(
{
orgId: org.orgId,
userId: userId!,
roleId: org.roleId,
autoProvisioned: true,
},
org.roleIds,
trx
);
}
@@ -748,3 +766,25 @@ function hydrateOrgMapping(
}
return orgMapping.split("{{orgId}}").join(orgId);
}
function normalizeRoleMappingResult(
result: unknown
): string[] {
if (typeof result === "string") {
const role = result.trim();
return role ? [role] : [];
}
if (Array.isArray(result)) {
return [
...new Set(
result
.filter((value): value is string => typeof value === "string")
.map((value) => value.trim())
.filter(Boolean)
)
];
}
return [];
}

View File

@@ -16,6 +16,7 @@ import {
verifyApiKey,
verifyApiKeyOrgAccess,
verifyApiKeyHasAction,
verifyApiKeyCanSetUserOrgRoles,
verifyApiKeySiteAccess,
verifyApiKeyResourceAccess,
verifyApiKeyTargetAccess,
@@ -135,6 +136,13 @@ authenticated.post(
logActionAudit(ActionsEnum.updateSite),
site.updateSite
);
authenticated.post(
"/org/:orgId/reset-bandwidth",
verifyApiKeyOrgAccess,
verifyApiKeyHasAction(ActionsEnum.resetSiteBandwidth),
logActionAudit(ActionsEnum.resetSiteBandwidth),
org.resetOrgBandwidth
);
authenticated.delete(
"/site/:siteId",
@@ -588,7 +596,7 @@ authenticated.post(
verifyLimits,
verifyApiKeyHasAction(ActionsEnum.addUserRole),
logActionAudit(ActionsEnum.addUserRole),
user.addUserRole
user.addUserRoleLegacy
);
authenticated.post(

View File

@@ -18,8 +18,8 @@ import { eq, and } from "drizzle-orm";
import config from "@server/lib/config";
import {
formatEndpoint,
generateSubnetProxyTargets,
SubnetProxyTarget
generateSubnetProxyTargetV2,
SubnetProxyTargetV2
} from "@server/lib/ip";
export async function buildClientConfigurationForNewtClient(
@@ -148,7 +148,7 @@ export async function buildClientConfigurationForNewtClient(
.where(eq(siteNetworks.siteId, siteId))
.then((rows) => rows.map((r) => r.siteResources));
const targetsToSend: SubnetProxyTarget[] = [];
const targetsToSend: SubnetProxyTargetV2[] = [];
for (const resource of allSiteResources) {
// Get clients associated with this specific resource
@@ -173,12 +173,14 @@ export async function buildClientConfigurationForNewtClient(
)
);
const resourceTargets = generateSubnetProxyTargets(
const resourceTarget = generateSubnetProxyTargetV2(
resource,
resourceClients
);
targetsToSend.push(...resourceTargets);
if (resourceTarget) {
targetsToSend.push(resourceTarget);
}
}
return {

View File

@@ -46,7 +46,7 @@ export async function createNewt(
const { newtId, secret } = parsedBody.data;
if (req.user && !req.userOrgRoleId) {
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);

View File

@@ -1,6 +1,8 @@
import { generateSessionToken } from "@server/auth/sessions/app";
import { db } from "@server/db";
import { db, newtSessions } from "@server/db";
import { newts } from "@server/db";
import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache";
import { EXPIRES } from "@server/auth/sessions/newt";
import HttpCode from "@server/types/HttpCode";
import response from "@server/lib/response";
import { eq } from "drizzle-orm";
@@ -92,8 +94,19 @@ export async function getNewtToken(
);
}
const resToken = generateSessionToken();
await createNewtSession(resToken, existingNewt.newtId);
// Return a cached token if one exists to prevent thundering herd on
// simultaneous restarts; falls back to creating a fresh session when
// Redis is unavailable or the cache has expired.
const resToken = await getOrCreateCachedToken(
`newt:token_cache:${existingNewt.newtId}`,
config.getRawConfig().server.secret!,
Math.floor(EXPIRES / 1000),
async () => {
const token = generateSessionToken();
await createNewtSession(token, existingNewt.newtId);
return token;
}
);
return response<{ token: string; serverVersion: string }>(res, {
data: {

View File

@@ -0,0 +1,13 @@
import { MessageHandler } from "@server/routers/ws";
export async function flushConnectionLogToDb(): Promise<void> {
return;
}
export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
return;
}
export const handleConnectionLogMessage: MessageHandler = async (context) => {
return;
};

View File

@@ -6,14 +6,9 @@ import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { sendToExitNode } from "#dynamic/lib/exitNodes";
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
import { convertTargetsIfNessicary } from "../client/targets";
import { canCompress } from "@server/lib/clientVersionChecks";
const inputSchema = z.object({
publicKey: z.string(),
port: z.int().positive()
});
type Input = z.infer<typeof inputSchema>;
import config from "@server/lib/config";
export const handleGetConfigMessage: MessageHandler = async (context) => {
const { message, client, sendToClient } = context;
@@ -33,16 +28,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
return;
}
const parsed = inputSchema.safeParse(message.data);
if (!parsed.success) {
logger.error(
"handleGetConfigMessage: Invalid input: " +
fromError(parsed.error).toString()
);
return;
}
const { publicKey, port } = message.data as Input;
const { publicKey, port, chainId } = message.data;
const siteId = newt.siteId;
// Get the current site data
@@ -70,7 +56,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
if (existingSite.lastHolePunch && now - existingSite.lastHolePunch > 5) {
logger.warn(
`handleGetConfigMessage: Site ${existingSite.siteId} last hole punch is too old, skipping`
`Site last hole punch is too old; skipping this register. The site is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
);
return;
}
@@ -127,13 +113,16 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
exitNode
);
const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets);
return {
message: {
type: "newt/wg/receive-config",
data: {
ipAddress: site.address,
peers,
targets
targets: targetsToSend,
chainId: chainId
}
},
options: {

View File

@@ -6,7 +6,9 @@ import logger from "@server/logger";
/**
* Handles disconnecting messages from sites to show disconnected in the ui
*/
export const handleNewtDisconnectingMessage: MessageHandler = async (context) => {
export const handleNewtDisconnectingMessage: MessageHandler = async (
context
) => {
const { message, client: c, sendToClient } = context;
const newt = c as Newt;
@@ -27,7 +29,7 @@ export const handleNewtDisconnectingMessage: MessageHandler = async (context) =>
.set({
online: false
})
.where(eq(sites.siteId, sites.siteId));
.where(eq(sites.siteId, newt.siteId));
} catch (error) {
logger.error("Error handling disconnecting message", { error });
}

View File

@@ -1,15 +1,20 @@
import { db, newts, sites } from "@server/db";
import { hasActiveConnections, getClientConfigVersion } from "#dynamic/routers/ws";
import { db, newts, sites, targetHealthCheck, targets } from "@server/db";
import {
hasActiveConnections,
getClientConfigVersion
} from "#dynamic/routers/ws";
import { MessageHandler } from "@server/routers/ws";
import { Newt } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import { eq, lt, isNull, and, or, ne, not } from "drizzle-orm";
import logger from "@server/logger";
import { sendNewtSyncMessage } from "./sync";
import { recordPing } from "./pingAccumulator";
// Track if the offline checker interval is running
let offlineCheckerInterval: NodeJS.Timeout | null = null;
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
const OFFLINE_THRESHOLD_BANDWIDTH_MS = 8 * 60 * 1000; // 8 minutes
/**
* Starts the background interval that checks for newt sites that haven't
@@ -55,7 +60,9 @@ export const startNewtOfflineChecker = (): void => {
// Backward-compatibility check: if the newt still has an
// active WebSocket connection (older clients that don't send
// pings), keep the site online.
const isConnected = await hasActiveConnections(staleSite.newtId);
const isConnected = await hasActiveConnections(
staleSite.newtId
);
if (isConnected) {
logger.debug(
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
@@ -71,6 +78,83 @@ export const startNewtOfflineChecker = (): void => {
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, staleSite.siteId));
const healthChecksOnSite = await db
.select()
.from(targetHealthCheck)
.innerJoin(
targets,
eq(targets.targetId, targetHealthCheck.targetId)
)
.innerJoin(sites, eq(sites.siteId, targets.siteId))
.where(eq(sites.siteId, staleSite.siteId));
for (const healthCheck of healthChecksOnSite) {
logger.info(
`Marking health check ${healthCheck.targetHealthCheck.targetHealthCheckId} offline due to site ${staleSite.siteId} being marked offline`
);
await db
.update(targetHealthCheck)
.set({ hcHealth: "unknown" })
.where(
eq(
targetHealthCheck.targetHealthCheckId,
healthCheck.targetHealthCheck
.targetHealthCheckId
)
);
}
}
// this part only effects self hosted. Its not efficient but we dont expect people to have very many wireguard sites
// select all of the wireguard sites to evaluate if they need to be offline due to the last bandwidth update
const allWireguardSites = await db
.select({
siteId: sites.siteId,
online: sites.online,
lastBandwidthUpdate: sites.lastBandwidthUpdate
})
.from(sites)
.where(
and(
eq(sites.type, "wireguard"),
not(isNull(sites.lastBandwidthUpdate))
)
);
const wireguardOfflineThreshold = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_BANDWIDTH_MS) / 1000
);
// loop over each one. If its offline and there is a new update then mark it online. If its online and there is no update then mark it offline
for (const site of allWireguardSites) {
const lastBandwidthUpdate =
new Date(site.lastBandwidthUpdate!).getTime() / 1000;
if (
lastBandwidthUpdate < wireguardOfflineThreshold &&
site.online
) {
logger.info(
`Marking wireguard site ${site.siteId} offline: no bandwidth update in over ${OFFLINE_THRESHOLD_BANDWIDTH_MS / 60000} minutes`
);
await db
.update(sites)
.set({ online: false })
.where(eq(sites.siteId, site.siteId));
} else if (
lastBandwidthUpdate >= wireguardOfflineThreshold &&
!site.online
) {
logger.info(
`Marking wireguard site ${site.siteId} online: recent bandwidth update`
);
await db
.update(sites)
.set({ online: true })
.where(eq(sites.siteId, site.siteId));
}
}
} catch (error) {
logger.error("Error in newt offline checker interval", { error });
@@ -114,18 +198,12 @@ export const handleNewtPingMessage: MessageHandler = async (context) => {
return;
}
try {
// Mark the site as online and record the ping timestamp.
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newt.siteId));
} catch (error) {
logger.error("Error updating online state on newt ping", { error });
}
// Record the ping in memory; it will be flushed to the database
// periodically by the ping accumulator (every ~10s) in a single
// batched UPDATE instead of one query per ping. This prevents
// connection pool exhaustion under load, especially with
// cross-region latency to the database.
recordPing(newt.siteId);
// Check config version and sync if stale.
const configVersion = await getClientConfigVersion(newt.newtId);

View File

@@ -33,7 +33,7 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => {
return;
}
const { noCloud } = message.data;
const { noCloud, chainId } = message.data;
const exitNodesList = await listExitNodes(
site.orgId,
@@ -98,7 +98,8 @@ export const handleNewtPingRequestMessage: MessageHandler = async (context) => {
message: {
type: "newt/ping/exitNodes",
data: {
exitNodes: filteredExitNodes
exitNodes: filteredExitNodes,
chainId: chainId
}
},
broadcast: false, // Send to all clients

View File

@@ -43,7 +43,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
const siteId = newt.siteId;
const { publicKey, pingResults, newtVersion, backwardsCompatible } =
const { publicKey, pingResults, newtVersion, backwardsCompatible, chainId } =
message.data;
if (!publicKey) {
logger.warn("Public key not provided");
@@ -211,7 +211,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
udp: udpTargets,
tcp: tcpTargets
},
healthCheckTargets: validHealthCheckTargets
healthCheckTargets: validHealthCheckTargets,
chainId: chainId
}
},
options: {

View File

@@ -8,3 +8,5 @@ export * from "./handleNewtPingRequestMessage";
export * from "./handleApplyBlueprintMessage";
export * from "./handleNewtPingMessage";
export * from "./handleNewtDisconnectingMessage";
export * from "./handleConnectionLogMessage";
export * from "./registerNewt";

View File

@@ -0,0 +1,392 @@
import { db } from "@server/db";
import { sites, clients, olms } from "@server/db";
import { inArray } from "drizzle-orm";
import logger from "@server/logger";
/**
* Ping Accumulator
*
* Instead of writing to the database on every single newt/olm ping (which
* causes pool exhaustion under load, especially with cross-region latency),
* we accumulate pings in memory and flush them to the database periodically
* in a single batch.
*
* This is the same pattern used for bandwidth flushing in
* receiveBandwidth.ts and handleReceiveBandwidthMessage.ts.
*
* Supports two kinds of pings:
* - **Site pings** (from newts): update `sites.online` and `sites.lastPing`
* - **Client pings** (from OLMs): update `clients.online`, `clients.lastPing`,
* `clients.archived`, and optionally reset `olms.archived`
*/
const FLUSH_INTERVAL_MS = 10_000; // Flush every 10 seconds
const MAX_RETRIES = 5;
const BASE_DELAY_MS = 50;
// ── Site (newt) pings ──────────────────────────────────────────────────
// Map of siteId -> latest ping timestamp (unix seconds)
const pendingSitePings: Map<number, number> = new Map();
// ── Client (OLM) pings ────────────────────────────────────────────────
// Map of clientId -> latest ping timestamp (unix seconds)
const pendingClientPings: Map<number, number> = new Map();
// Set of olmIds whose `archived` flag should be reset to false
const pendingOlmArchiveResets: Set<string> = new Set();
let flushTimer: NodeJS.Timeout | null = null;
/**
* Guard that prevents two flush cycles from running concurrently.
* setInterval does not await async callbacks, so without this a slow flush
* (e.g. due to DB latency) would overlap with the next scheduled cycle and
* the two concurrent bulk UPDATEs would deadlock each other.
*/
let isFlushing = false;
// ── Public API ─────────────────────────────────────────────────────────
/**
* Record a ping for a newt site. This does NOT write to the database
* immediately. Instead it stores the latest ping timestamp in memory,
* to be flushed periodically by the background timer.
*/
export function recordSitePing(siteId: number): void {
const now = Math.floor(Date.now() / 1000);
pendingSitePings.set(siteId, now);
}
/** @deprecated Use `recordSitePing` instead. Alias kept for existing call-sites. */
export const recordPing = recordSitePing;
/**
* Record a ping for an OLM client. Batches the `clients` table update
* (`online`, `lastPing`, `archived`) and, when `olmArchived` is true,
* also queues an `olms` table update to clear the archived flag.
*/
export function recordClientPing(
clientId: number,
olmId: string,
olmArchived: boolean
): void {
const now = Math.floor(Date.now() / 1000);
pendingClientPings.set(clientId, now);
if (olmArchived) {
pendingOlmArchiveResets.add(olmId);
}
}
// ── Flush Logic ────────────────────────────────────────────────────────
/**
* Flush all accumulated site pings to the database.
*
* Each batch of up to BATCH_SIZE rows is written with a **single** UPDATE
* statement. We use the maximum timestamp across the batch so that `lastPing`
* reflects the most recent ping seen for any site in the group. This avoids
* the multi-statement transaction that previously created additional
* row-lock ordering hazards.
*/
async function flushSitePingsToDb(): Promise<void> {
if (pendingSitePings.size === 0) {
return;
}
// Snapshot and clear so new pings arriving during the flush go into a
// fresh map for the next cycle.
const pingsToFlush = new Map(pendingSitePings);
pendingSitePings.clear();
const entries = Array.from(pingsToFlush.entries());
const BATCH_SIZE = 50;
for (let i = 0; i < entries.length; i += BATCH_SIZE) {
const batch = entries.slice(i, i + BATCH_SIZE);
// Use the latest timestamp in the batch so that `lastPing` always
// moves forward. Using a single timestamp for the whole batch means
// we only ever need one UPDATE statement (no transaction).
const maxTimestamp = Math.max(...batch.map(([, ts]) => ts));
const siteIds = batch.map(([id]) => id);
try {
await withRetry(async () => {
await db
.update(sites)
.set({
online: true,
lastPing: maxTimestamp
})
.where(inArray(sites.siteId, siteIds));
}, "flushSitePingsToDb");
} catch (error) {
logger.error(
`Failed to flush site ping batch (${batch.length} sites), re-queuing for next cycle`,
{ error }
);
// Re-queue only if the preserved timestamp is newer than any
// update that may have landed since we snapshotted.
for (const [siteId, timestamp] of batch) {
const existing = pendingSitePings.get(siteId);
if (!existing || existing < timestamp) {
pendingSitePings.set(siteId, timestamp);
}
}
}
}
}
/**
* Flush all accumulated client (OLM) pings to the database.
*
* Same single-UPDATE-per-batch approach as `flushSitePingsToDb`.
*/
async function flushClientPingsToDb(): Promise<void> {
if (pendingClientPings.size === 0 && pendingOlmArchiveResets.size === 0) {
return;
}
// Snapshot and clear
const pingsToFlush = new Map(pendingClientPings);
pendingClientPings.clear();
const olmResetsToFlush = new Set(pendingOlmArchiveResets);
pendingOlmArchiveResets.clear();
// ── Flush client pings ─────────────────────────────────────────────
if (pingsToFlush.size > 0) {
const entries = Array.from(pingsToFlush.entries());
const BATCH_SIZE = 50;
for (let i = 0; i < entries.length; i += BATCH_SIZE) {
const batch = entries.slice(i, i + BATCH_SIZE);
const maxTimestamp = Math.max(...batch.map(([, ts]) => ts));
const clientIds = batch.map(([id]) => id);
try {
await withRetry(async () => {
await db
.update(clients)
.set({
lastPing: maxTimestamp,
online: true,
archived: false
})
.where(inArray(clients.clientId, clientIds));
}, "flushClientPingsToDb");
} catch (error) {
logger.error(
`Failed to flush client ping batch (${batch.length} clients), re-queuing for next cycle`,
{ error }
);
for (const [clientId, timestamp] of batch) {
const existing = pendingClientPings.get(clientId);
if (!existing || existing < timestamp) {
pendingClientPings.set(clientId, timestamp);
}
}
}
}
}
// ── Flush OLM archive resets ───────────────────────────────────────
if (olmResetsToFlush.size > 0) {
const olmIds = Array.from(olmResetsToFlush).sort();
const BATCH_SIZE = 50;
for (let i = 0; i < olmIds.length; i += BATCH_SIZE) {
const batch = olmIds.slice(i, i + BATCH_SIZE);
try {
await withRetry(async () => {
await db
.update(olms)
.set({ archived: false })
.where(inArray(olms.olmId, batch));
}, "flushOlmArchiveResets");
} catch (error) {
logger.error(
`Failed to flush OLM archive reset batch (${batch.length} olms), re-queuing for next cycle`,
{ error }
);
for (const olmId of batch) {
pendingOlmArchiveResets.add(olmId);
}
}
}
}
}
/**
* Flush everything — called by the interval timer and during shutdown.
*/
export async function flushPingsToDb(): Promise<void> {
await flushSitePingsToDb();
await flushClientPingsToDb();
}
// ── Retry / Error Helpers ──────────────────────────────────────────────
/**
* Simple retry wrapper with exponential backoff for transient errors
* (deadlocks, connection timeouts, unexpected disconnects).
*
* PostgreSQL deadlocks (40P01) are always safe to retry: the database
* guarantees exactly one winner per deadlock pair, so the loser just needs
* to try again. MAX_RETRIES is intentionally higher than typical connection
* retry budgets to give deadlock victims enough chances to succeed.
*/
async function withRetry<T>(
operation: () => Promise<T>,
context: string
): Promise<T> {
let attempt = 0;
while (true) {
try {
return await operation();
} catch (error: any) {
if (isTransientError(error) && attempt < MAX_RETRIES) {
attempt++;
const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS;
const jitter = Math.random() * baseDelay;
const delay = baseDelay + jitter;
logger.warn(
`Transient DB error in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms`,
{ code: error?.code ?? error?.cause?.code }
);
await new Promise((resolve) => setTimeout(resolve, delay));
continue;
}
throw error;
}
}
}
/**
* Detect transient errors that are safe to retry.
*/
function isTransientError(error: any): boolean {
if (!error) return false;
const message = (error.message || "").toLowerCase();
const causeMessage = (error.cause?.message || "").toLowerCase();
const code = error.code || error.cause?.code || "";
// Connection timeout / terminated
if (
message.includes("connection timeout") ||
message.includes("connection terminated") ||
message.includes("timeout exceeded when trying to connect") ||
causeMessage.includes("connection terminated unexpectedly") ||
causeMessage.includes("connection timeout")
) {
return true;
}
// PostgreSQL deadlock detected — always safe to retry (one winner guaranteed)
if (code === "40P01" || message.includes("deadlock")) {
return true;
}
// PostgreSQL serialization failure
if (code === "40001") {
return true;
}
// ECONNRESET, ECONNREFUSED, EPIPE, ETIMEDOUT
if (
code === "ECONNRESET" ||
code === "ECONNREFUSED" ||
code === "EPIPE" ||
code === "ETIMEDOUT"
) {
return true;
}
return false;
}
// ── Lifecycle ──────────────────────────────────────────────────────────
/**
* Start the background flush timer. Call this once at server startup.
*/
export function startPingAccumulator(): void {
if (flushTimer) {
return; // Already running
}
flushTimer = setInterval(async () => {
// Skip this tick if the previous flush is still in progress.
// setInterval does not await async callbacks, so without this guard
// two flush cycles can run concurrently and deadlock each other on
// overlapping bulk UPDATE statements.
if (isFlushing) {
logger.debug(
"Ping accumulator: previous flush still in progress, skipping cycle"
);
return;
}
isFlushing = true;
try {
await flushPingsToDb();
} catch (error) {
logger.error("Unhandled error in ping accumulator flush", {
error
});
} finally {
isFlushing = false;
}
}, FLUSH_INTERVAL_MS);
// Don't prevent the process from exiting
flushTimer.unref();
logger.info(
`Ping accumulator started (flush interval: ${FLUSH_INTERVAL_MS}ms)`
);
}
/**
* Stop the background flush timer and perform a final flush.
* Call this during graceful shutdown.
*/
export async function stopPingAccumulator(): Promise<void> {
if (flushTimer) {
clearInterval(flushTimer);
flushTimer = null;
}
// Final flush to persist any remaining pings.
// Wait for any in-progress flush to finish first so we don't race.
if (isFlushing) {
logger.debug(
"Ping accumulator: waiting for in-progress flush before stopping…"
);
await new Promise<void>((resolve) => {
const poll = setInterval(() => {
if (!isFlushing) {
clearInterval(poll);
resolve();
}
}, 50);
});
}
try {
await flushPingsToDb();
} catch (error) {
logger.error("Error during final ping accumulator flush", { error });
}
logger.info("Ping accumulator stopped");
}
/**
* Get the number of pending (unflushed) pings. Useful for monitoring.
*/
export function getPendingPingCount(): number {
return pendingSitePings.size + pendingClientPings.size;
}

View File

@@ -0,0 +1,289 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import {
siteProvisioningKeys,
siteProvisioningKeyOrg,
newts,
orgs,
roles,
roleSites,
sites
} from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { eq, and, sql } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import { verifyPassword, hashPassword } from "@server/auth/password";
import {
generateId,
generateIdFromEntropySize
} from "@server/auth/sessions/app";
import { getUniqueSiteName } from "@server/db/names";
import moment from "moment";
import { build } from "@server/build";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { INSPECT_MAX_BYTES } from "buffer";
import { getNextAvailableClientSubnet } from "@server/lib/ip";
const bodySchema = z.object({
provisioningKey: z.string().nonempty(),
name: z.string().optional()
});
export type RegisterNewtBody = z.infer<typeof bodySchema>;
export type RegisterNewtResponse = {
newtId: string;
secret: string;
};
export async function registerNewt(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { provisioningKey, name } = parsedBody.data;
// Keys are in the format "siteProvisioningKeyId.secret"
const dotIndex = provisioningKey.indexOf(".");
if (dotIndex === -1) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid provisioning key format"
)
);
}
const provisioningKeyId = provisioningKey.substring(0, dotIndex);
const provisioningKeySecret = provisioningKey.substring(dotIndex + 1);
// Look up the provisioning key by ID, joining to get the orgId
const [keyRecord] = await db
.select({
siteProvisioningKeyId:
siteProvisioningKeys.siteProvisioningKeyId,
siteProvisioningKeyHash:
siteProvisioningKeys.siteProvisioningKeyHash,
orgId: siteProvisioningKeyOrg.orgId,
maxBatchSize: siteProvisioningKeys.maxBatchSize,
numUsed: siteProvisioningKeys.numUsed,
validUntil: siteProvisioningKeys.validUntil,
approveNewSites: siteProvisioningKeys.approveNewSites,
})
.from(siteProvisioningKeys)
.innerJoin(
siteProvisioningKeyOrg,
eq(
siteProvisioningKeys.siteProvisioningKeyId,
siteProvisioningKeyOrg.siteProvisioningKeyId
)
)
.where(
eq(
siteProvisioningKeys.siteProvisioningKeyId,
provisioningKeyId
)
)
.limit(1);
if (!keyRecord) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Invalid provisioning key"
)
);
}
// Verify the secret portion against the stored hash
const validSecret = await verifyPassword(
provisioningKeySecret,
keyRecord.siteProvisioningKeyHash
);
if (!validSecret) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Invalid provisioning key"
)
);
}
if (keyRecord.maxBatchSize && keyRecord.numUsed >= keyRecord.maxBatchSize) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Provisioning key has reached its maximum usage"
)
);
}
if (keyRecord.validUntil && new Date(keyRecord.validUntil) < new Date()) {
return next(
createHttpError(
HttpCode.UNAUTHORIZED,
"Provisioning key has expired"
)
);
}
const { orgId } = keyRecord;
// Verify the org exists
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
if (!org) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
);
}
if (!org.subnet) {
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Organization subnet not found")
);
}
// SaaS billing check
if (build == "saas") {
const usage = await usageService.getUsage(orgId, FeatureId.SITES);
if (!usage) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"No usage data found for this organization"
)
);
}
const rejectSites = await usageService.checkLimitSet(
orgId,
FeatureId.SITES,
{
...usage,
instantaneousValue: (usage.instantaneousValue || 0) + 1
}
);
if (rejectSites) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Site limit exceeded. Please upgrade your plan."
)
);
}
}
const niceId = await getUniqueSiteName(orgId);
const newtId = generateId(15);
const newtSecret = generateIdFromEntropySize(25);
const secretHash = await hashPassword(newtSecret);
let newSiteId: number | undefined;
await db.transaction(async (trx) => {
const newClientAddress = await getNextAvailableClientSubnet(orgId);
if (!newClientAddress) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"No available subnet found"
)
);
}
let clientAddress = newClientAddress.split("/")[0];
clientAddress = `${clientAddress}/${org.subnet!.split("/")[1]}`; // we want the block size of the whole org
// Create the site (type "newt", name = niceId)
const [newSite] = await trx
.insert(sites)
.values({
orgId,
name: name || niceId,
niceId,
address: clientAddress,
type: "newt",
dockerSocketEnabled: true,
status: keyRecord.approveNewSites ? "approved" : "pending",
})
.returning();
newSiteId = newSite.siteId;
// Grant admin role access to the new site
const [adminRole] = await trx
.select()
.from(roles)
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
.limit(1);
if (!adminRole) {
throw new Error(`Admin role not found for org ${orgId}`);
}
await trx.insert(roleSites).values({
roleId: adminRole.roleId,
siteId: newSite.siteId
});
// Create the newt for this site
await trx.insert(newts).values({
newtId,
secretHash,
siteId: newSite.siteId,
dateCreated: moment().toISOString()
});
// Consume the provisioning key — cascade removes siteProvisioningKeyOrg
await trx
.update(siteProvisioningKeys)
.set({
lastUsed: moment().toISOString(),
numUsed: sql`${siteProvisioningKeys.numUsed} + 1`
})
.where(
eq(
siteProvisioningKeys.siteProvisioningKeyId,
provisioningKeyId
)
);
await usageService.add(orgId, FeatureId.SITES, 1, trx);
});
logger.info(
`Provisioned new site (ID: ${newSiteId}) and newt (ID: ${newtId}) for org ${orgId} via provisioning key ${provisioningKeyId}`
);
return response<RegisterNewtResponse>(res, {
data: {
newtId,
secret: newtSecret
},
success: true,
error: false,
message: "Newt registered successfully",
status: HttpCode.CREATED
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -46,7 +46,7 @@ export async function createNewt(
const { newtId, secret } = parsedBody.data;
if (req.user && !req.userOrgRoleId) {
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);

View File

@@ -8,7 +8,7 @@ import {
ExitNode,
exitNodes,
sites,
clientSitesAssociationsCache
clientSitesAssociationsCache,
} from "@server/db";
import { olms } from "@server/db";
import HttpCode from "@server/types/HttpCode";
@@ -20,8 +20,10 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
import {
createOlmSession,
validateOlmSessionToken
validateOlmSessionToken,
EXPIRES
} from "@server/auth/sessions/olm";
import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache";
import { verifyPassword } from "@server/auth/password";
import logger from "@server/logger";
import config from "@server/lib/config";
@@ -132,8 +134,19 @@ export async function getOlmToken(
logger.debug("Creating new olm session token");
const resToken = generateSessionToken();
await createOlmSession(resToken, existingOlm.olmId);
// Return a cached token if one exists to prevent thundering herd on
// simultaneous restarts; falls back to creating a fresh session when
// Redis is unavailable or the cache has expired.
const resToken = await getOrCreateCachedToken(
`olm:token_cache:${existingOlm.olmId}`,
config.getRawConfig().server.secret!,
Math.floor(EXPIRES / 1000),
async () => {
const token = generateSessionToken();
await createOlmSession(token, existingOlm.olmId);
return token;
}
);
let clientIdToUse;
if (orgId) {

View File

@@ -3,6 +3,7 @@ import { db } from "@server/db";
import { MessageHandler } from "@server/routers/ws";
import { clients, olms, Olm } from "@server/db";
import { eq, lt, isNull, and, or } from "drizzle-orm";
import { recordClientPing } from "@server/routers/newt/pingAccumulator";
import logger from "@server/logger";
import { validateSessionToken } from "@server/auth/sessions/app";
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
@@ -201,22 +202,12 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
await sendOlmSyncMessage(olm, client);
}
// Update the client's last ping timestamp
await db
.update(clients)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true,
archived: false
})
.where(eq(clients.clientId, olm.clientId));
if (olm.archived) {
await db
.update(olms)
.set({ archived: false })
.where(eq(olms.olmId, olm.olmId));
}
// Record the ping in memory; it will be flushed to the database
// periodically by the ping accumulator (every ~10s) in a single
// batched UPDATE instead of one query per ping. This prevents
// connection pool exhaustion under load, especially with
// cross-region latency to the database.
recordClientPing(olm.clientId, olm.olmId, !!olm.archived);
} catch (error) {
logger.error("Error handling ping message", { error });
}

View File

@@ -20,6 +20,7 @@ import { handleFingerprintInsertion } from "./fingerprintingUtils";
import { Alias } from "@server/lib/ip";
import { build } from "@server/build";
import { canCompress } from "@server/lib/clientVersionChecks";
import config from "@server/lib/config";
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
logger.info("Handling register olm message!");
@@ -41,7 +42,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
orgId,
userToken,
fingerprint,
postures
postures,
chainId
} = message.data;
if (!olm.clientId) {
@@ -273,7 +275,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
// TODO: I still think there is a better way to do this rather than locking it out here but ???
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
logger.warn(
"Client last hole punch is too old and we have sites to send; skipping this register"
`Client last hole punch is too old and we have sites to send; skipping this register. The client is failing to hole punch and identify its network address with the server. Can the client reach the server on UDP port ${config.getRawConfig().gerbil.clients_start_port}?`
);
return;
}
@@ -293,7 +295,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
data: {
sites: siteConfigurations,
tunnelIP: client.subnet,
utilitySubnet: org.utilitySubnet
utilitySubnet: org.utilitySubnet,
chainId: chainId
}
},
options: {

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, idp, idpOidcConfig } from "@server/db";
import { roles, userOrgs, users } from "@server/db";
import { roles, userOrgRoles, userOrgs, users } from "@server/db";
import { and, eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -14,7 +14,7 @@ import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
import { CheckOrgAccessPolicyResult } from "@server/lib/checkOrgAccessPolicy";
async function queryUser(orgId: string, userId: string) {
const [user] = await db
const [userRow] = await db
.select({
orgId: userOrgs.orgId,
userId: users.userId,
@@ -22,10 +22,7 @@ async function queryUser(orgId: string, userId: string) {
username: users.username,
name: users.name,
type: users.type,
roleId: userOrgs.roleId,
roleName: roles.name,
isOwner: userOrgs.isOwner,
isAdmin: roles.isAdmin,
twoFactorEnabled: users.twoFactorEnabled,
autoProvisioned: userOrgs.autoProvisioned,
idpId: users.idpId,
@@ -35,13 +32,40 @@ async function queryUser(orgId: string, userId: string) {
idpAutoProvision: idp.autoProvision
})
.from(userOrgs)
.leftJoin(roles, eq(userOrgs.roleId, roles.roleId))
.leftJoin(users, eq(userOrgs.userId, users.userId))
.leftJoin(idp, eq(users.idpId, idp.idpId))
.leftJoin(idpOidcConfig, eq(idp.idpId, idpOidcConfig.idpId))
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
.limit(1);
return user;
if (!userRow) return undefined;
const roleRows = await db
.select({
roleId: userOrgRoles.roleId,
roleName: roles.name,
isAdmin: roles.isAdmin
})
.from(userOrgRoles)
.leftJoin(roles, eq(userOrgRoles.roleId, roles.roleId))
.where(
and(
eq(userOrgRoles.userId, userId),
eq(userOrgRoles.orgId, orgId)
)
);
const isAdmin = roleRows.some((r) => r.isAdmin);
return {
...userRow,
isAdmin,
roleIds: roleRows.map((r) => r.roleId),
roles: roleRows.map((r) => ({
roleId: r.roleId,
name: r.roleName ?? ""
}))
};
}
export type CheckOrgUserAccessResponse = CheckOrgAccessPolicyResult;

View File

@@ -9,6 +9,7 @@ import {
orgs,
roleActions,
roles,
userOrgRoles,
userOrgs,
users,
actions
@@ -312,9 +313,13 @@ export async function createOrg(
await trx.insert(userOrgs).values({
userId: req.user!.userId,
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
await trx.insert(userOrgRoles).values({
userId: req.user!.userId,
orgId: newOrg[0].orgId,
roleId
});
ownerUserId = req.user!.userId;
} else {
// if org created by root api key, set the server admin as the owner
@@ -332,9 +337,13 @@ export async function createOrg(
await trx.insert(userOrgs).values({
userId: serverAdmin.userId,
orgId: newOrg[0].orgId,
roleId: roleId,
isOwner: true
});
await trx.insert(userOrgRoles).values({
userId: serverAdmin.userId,
orgId: newOrg[0].orgId,
roleId
});
ownerUserId = serverAdmin.userId;
}

View File

@@ -117,20 +117,26 @@ export async function getOrgOverview(
.from(userOrgs)
.where(eq(userOrgs.orgId, orgId));
const [role] = await db
.select()
.from(roles)
.where(eq(roles.roleId, req.userOrg.roleId));
const roleIds = req.userOrgRoleIds ?? [];
const roleRows =
roleIds.length > 0
? await db
.select({ name: roles.name, isAdmin: roles.isAdmin })
.from(roles)
.where(inArray(roles.roleId, roleIds))
: [];
const userRoleName = roleRows.map((r) => r.name ?? "").join(", ") ?? "";
const isAdmin = roleRows.some((r) => r.isAdmin === true);
return response<GetOrgOverviewResponse>(res, {
data: {
orgName: org[0].name,
orgId: org[0].orgId,
userRoleName: role.name,
userRoleName,
numSites,
numUsers,
numResources,
isAdmin: role.isAdmin || false,
isAdmin,
isOwner: req.userOrg?.isOwner || false
},
success: true,

View File

@@ -8,3 +8,4 @@ export * from "./getOrgOverview";
export * from "./listOrgs";
export * from "./pickOrgDefaults";
export * from "./checkOrgUserAccess";
export * from "./resetOrgBandwidth";

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, roles } from "@server/db";
import { Org, orgs, userOrgs } from "@server/db";
import { Org, orgs, userOrgRoles, userOrgs } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -82,10 +82,7 @@ export async function listUserOrgs(
const { userId } = parsedParams.data;
const userOrganizations = await db
.select({
orgId: userOrgs.orgId,
roleId: userOrgs.roleId
})
.select({ orgId: userOrgs.orgId })
.from(userOrgs)
.where(eq(userOrgs.userId, userId));
@@ -116,10 +113,27 @@ export async function listUserOrgs(
userOrgs,
and(eq(userOrgs.orgId, orgs.orgId), eq(userOrgs.userId, userId))
)
.leftJoin(roles, eq(userOrgs.roleId, roles.roleId))
.limit(limit)
.offset(offset);
const roleRows = await db
.select({
orgId: userOrgRoles.orgId,
isAdmin: roles.isAdmin
})
.from(userOrgRoles)
.leftJoin(roles, eq(userOrgRoles.roleId, roles.roleId))
.where(
and(
eq(userOrgRoles.userId, userId),
inArray(userOrgRoles.orgId, userOrgIds)
)
);
const orgHasAdmin = new Set(
roleRows.filter((r) => r.isAdmin).map((r) => r.orgId)
);
const totalCountResult = await db
.select({ count: sql<number>`cast(count(*) as integer)` })
.from(orgs)
@@ -133,8 +147,8 @@ export async function listUserOrgs(
if (val.userOrgs && val.userOrgs.isOwner) {
res.isOwner = val.userOrgs.isOwner;
}
if (val.roles && val.roles.isAdmin) {
res.isAdmin = val.roles.isAdmin;
if (val.orgs && orgHasAdmin.has(val.orgs.orgId)) {
res.isAdmin = true;
}
if (val.userOrgs?.isOwner && val.orgs?.isBillingOrg) {
res.isPrimaryOrg = val.orgs.isBillingOrg;

View File

@@ -0,0 +1,83 @@
import { NextFunction, Request, Response } from "express";
import { z } from "zod";
import { db, sites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const resetOrgBandwidthParamsSchema = z.strictObject({
orgId: z.string()
});
registry.registerPath({
method: "post",
path: "/org/{orgId}/reset-bandwidth",
description: "Reset all sites in selected organization bandwidth counters.",
tags: [OpenAPITags.Org, OpenAPITags.Site],
request: {
params: resetOrgBandwidthParamsSchema
},
responses: {}
});
export async function resetOrgBandwidth(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = resetOrgBandwidthParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const { orgId } = parsedParams.data;
const [site] = await db
.select({ siteId: sites.siteId })
.from(sites)
.where(eq(sites.orgId, orgId))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`No sites found in org ${orgId}`
)
);
}
await db
.update(sites)
.set({
megabytesIn: 0,
megabytesOut: 0
})
.where(eq(sites.orgId, orgId));
return response(res, {
data: {},
success: true,
error: false,
message: "Sites bandwidth reset successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -34,6 +34,10 @@ const updateOrgBodySchema = z
.min(build === "saas" ? 0 : -1)
.optional(),
settingsLogRetentionDaysAction: z
.number()
.min(build === "saas" ? 0 : -1)
.optional(),
settingsLogRetentionDaysConnection: z
.number()
.min(build === "saas" ? 0 : -1)
.optional()
@@ -164,6 +168,17 @@ export async function updateOrg(
)
);
}
if (
parsedBody.data.settingsLogRetentionDaysConnection !== undefined &&
parsedBody.data.settingsLogRetentionDaysConnection > maxRetentionDays
) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
`You are not allowed to set log retention days greater than ${maxRetentionDays} with your current subscription`
)
);
}
}
}
@@ -179,7 +194,9 @@ export async function updateOrg(
settingsLogRetentionDaysAccess:
parsedBody.data.settingsLogRetentionDaysAccess,
settingsLogRetentionDaysAction:
parsedBody.data.settingsLogRetentionDaysAction
parsedBody.data.settingsLogRetentionDaysAction,
settingsLogRetentionDaysConnection:
parsedBody.data.settingsLogRetentionDaysConnection
})
.where(eq(orgs.orgId, orgId))
.returning();
@@ -197,6 +214,7 @@ export async function updateOrg(
await cache.del(`org_${orgId}_retentionDays`);
await cache.del(`org_${orgId}_actionDays`);
await cache.del(`org_${orgId}_accessDays`);
await cache.del(`org_${orgId}_connectionDays`);
return response(res, {
data: updatedOrg[0],

View File

@@ -112,7 +112,7 @@ export async function createResource(
const { orgId } = parsedParams.data;
if (req.user && !req.userOrgRoleId) {
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
@@ -292,7 +292,7 @@ async function createHttpResource(
resourceId: newResource[0].resourceId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
if (req.user && !req.userOrgRoleIds?.includes(adminRole[0].roleId)) {
// make sure the user can access the resource
await trx.insert(userResources).values({
userId: req.user?.userId!,
@@ -385,7 +385,7 @@ async function createRawResource(
resourceId: newResource[0].resourceId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
if (req.user && !req.userOrgRoleIds?.includes(adminRole[0].roleId)) {
// make sure the user can access the resource
await trx.insert(userResources).values({
userId: req.user?.userId!,

View File

@@ -14,10 +14,11 @@ import {
isValidUrlGlobPattern
} from "@server/lib/validators";
import { OpenAPITags, registry } from "@server/openApi";
import { isValidRegionId } from "@server/db/regions";
const createResourceRuleSchema = z.strictObject({
action: z.enum(["ACCEPT", "DROP", "PASS"]),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY", "ASN"]),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY", "ASN", "REGION"]),
value: z.string().min(1),
priority: z.int(),
enabled: z.boolean().optional()
@@ -126,6 +127,15 @@ export async function createResourceRule(
)
);
}
} else if (match === "REGION") {
if (!isValidRegionId(value)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid region ID provided"
)
);
}
}
// Create the new resource rule

View File

@@ -1,15 +1,14 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { Resource, resources, sites } from "@server/db";
import { eq, and } from "drizzle-orm";
import { db, resources } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import stoi from "@server/lib/stoi";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import HttpCode from "@server/types/HttpCode";
import { and, eq } from "drizzle-orm";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const getResourceSchema = z.strictObject({
resourceId: z

View File

@@ -5,6 +5,7 @@ import {
resources,
userResources,
roleResources,
userOrgRoles,
userOrgs,
resourcePassword,
resourcePincode,
@@ -32,22 +33,29 @@ export async function getUserResources(
);
}
// First get the user's role in the organization
const userOrgResult = await db
.select({
roleId: userOrgs.roleId
})
// Check user is in organization and get their role IDs
const [userOrg] = await db
.select()
.from(userOrgs)
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
.limit(1);
if (userOrgResult.length === 0) {
if (!userOrg) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User not in organization")
);
}
const userRoleId = userOrgResult[0].roleId;
const userRoleIds = await db
.select({ roleId: userOrgRoles.roleId })
.from(userOrgRoles)
.where(
and(
eq(userOrgRoles.userId, userId),
eq(userOrgRoles.orgId, orgId)
)
)
.then((rows) => rows.map((r) => r.roleId));
// Get resources accessible through direct assignment or role assignment
const directResourcesQuery = db
@@ -55,20 +63,28 @@ export async function getUserResources(
.from(userResources)
.where(eq(userResources.userId, userId));
const roleResourcesQuery = db
.select({ resourceId: roleResources.resourceId })
.from(roleResources)
.where(eq(roleResources.roleId, userRoleId));
const roleResourcesQuery =
userRoleIds.length > 0
? db
.select({ resourceId: roleResources.resourceId })
.from(roleResources)
.where(inArray(roleResources.roleId, userRoleIds))
: Promise.resolve([]);
const directSiteResourcesQuery = db
.select({ siteResourceId: userSiteResources.siteResourceId })
.from(userSiteResources)
.where(eq(userSiteResources.userId, userId));
const roleSiteResourcesQuery = db
.select({ siteResourceId: roleSiteResources.siteResourceId })
.from(roleSiteResources)
.where(eq(roleSiteResources.roleId, userRoleId));
const roleSiteResourcesQuery =
userRoleIds.length > 0
? db
.select({
siteResourceId: roleSiteResources.siteResourceId
})
.from(roleSiteResources)
.where(inArray(roleSiteResources.roleId, userRoleIds))
: Promise.resolve([]);
const [directResources, roleResourceResults, directSiteResourceResults, roleSiteResourceResults] = await Promise.all([
directResourcesQuery,

View File

@@ -305,7 +305,7 @@ export async function listResources(
.where(
or(
eq(userResources.userId, req.user!.userId),
eq(roleResources.roleId, req.userOrgRoleId!)
inArray(roleResources.roleId, req.userOrgRoleIds!)
)
);
} else {

View File

@@ -14,6 +14,7 @@ import {
isValidUrlGlobPattern
} from "@server/lib/validators";
import { OpenAPITags, registry } from "@server/openApi";
import { isValidRegionId } from "@server/db/regions";
// Define Zod schema for request parameters validation
const updateResourceRuleParamsSchema = z.strictObject({
@@ -25,7 +26,7 @@ const updateResourceRuleParamsSchema = z.strictObject({
const updateResourceRuleSchema = z
.strictObject({
action: z.enum(["ACCEPT", "DROP", "PASS"]).optional(),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY", "ASN"]).optional(),
match: z.enum(["CIDR", "IP", "PATH", "COUNTRY", "ASN", "REGION"]).optional(),
value: z.string().min(1).optional(),
priority: z.int(),
enabled: z.boolean().optional()
@@ -166,6 +167,15 @@ export async function updateResourceRule(
)
);
}
} else if (match === "REGION") {
if (!isValidRegionId(value)) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid region ID provided"
)
);
}
}
}

View File

@@ -1,8 +1,8 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { roles, userOrgs } from "@server/db";
import { eq } from "drizzle-orm";
import { roles, userOrgRoles } from "@server/db";
import { and, eq, exists, aliasedTable } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -114,13 +114,32 @@ export async function deleteRole(
}
await db.transaction(async (trx) => {
// move all users from the userOrgs table with roleId to newRoleId
await trx
.update(userOrgs)
.set({ roleId: newRoleId })
.where(eq(userOrgs.roleId, roleId));
const uorNewRole = aliasedTable(userOrgRoles, "user_org_roles_new");
// Users who already have newRoleId: drop the old assignment only (unique on userId+orgId+roleId).
await trx.delete(userOrgRoles).where(
and(
eq(userOrgRoles.roleId, roleId),
exists(
trx
.select()
.from(uorNewRole)
.where(
and(
eq(uorNewRole.userId, userOrgRoles.userId),
eq(uorNewRole.orgId, userOrgRoles.orgId),
eq(uorNewRole.roleId, newRoleId)
)
)
)
)
);
await trx
.update(userOrgRoles)
.set({ roleId: newRoleId })
.where(eq(userOrgRoles.roleId, roleId));
// delete the old role
await trx.delete(roles).where(eq(roles.roleId, roleId));
});

View File

@@ -111,7 +111,7 @@ export async function createSite(
const { orgId } = parsedParams.data;
if (req.user && !req.userOrgRoleId) {
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
@@ -298,7 +298,8 @@ export async function createSite(
niceId,
address: updatedAddress || null,
type,
dockerSocketEnabled: true
dockerSocketEnabled: true,
status: "approved"
})
.returning();
} else if (type == "wireguard") {
@@ -355,7 +356,8 @@ export async function createSite(
niceId,
subnet,
type,
pubKey: pubKey || null
pubKey: pubKey || null,
status: "approved"
})
.returning();
} else if (type == "local") {
@@ -370,7 +372,8 @@ export async function createSite(
type,
dockerSocketEnabled: false,
online: true,
subnet: "0.0.0.0/32"
subnet: "0.0.0.0/32",
status: "approved"
})
.returning();
} else {
@@ -399,7 +402,7 @@ export async function createSite(
siteId: newSite.siteId
});
if (req.user && req.userOrgRoleId != adminRole[0].roleId) {
if (req.user && !req.userOrgRoleIds?.includes(adminRole[0].roleId)) {
// make sure the user can access the site
trx.insert(userSites).values({
userId: req.user?.userId!,

View File

@@ -55,7 +55,7 @@ async function getLatestNewtVersion(): Promise<string | null> {
tags = tags.filter((version) => !version.name.includes("rc"));
const latestVersion = tags[0].name;
await cache.set("latestNewtVersion", latestVersion);
await cache.set("latestNewtVersion", latestVersion, 3600);
return latestVersion;
} catch (error: any) {
@@ -135,6 +135,15 @@ const listSitesSchema = z.object({
.openapi({
type: "boolean",
description: "Filter by online status"
}),
status: z
.enum(["pending", "approved"])
.optional()
.catch(undefined)
.openapi({
type: "string",
enum: ["pending", "approved"],
description: "Filter by site status"
})
});
@@ -156,7 +165,8 @@ function querySitesBase() {
exitNodeId: sites.exitNodeId,
exitNodeName: exitNodes.name,
exitNodeEndpoint: exitNodes.endpoint,
remoteExitNodeId: remoteExitNodes.remoteExitNodeId
remoteExitNodeId: remoteExitNodes.remoteExitNodeId,
status: sites.status
})
.from(sites)
.leftJoin(orgs, eq(sites.orgId, orgs.orgId))
@@ -180,7 +190,7 @@ registry.registerPath({
method: "get",
path: "/org/{orgId}/sites",
description: "List all sites in an organization",
tags: [OpenAPITags.Site],
tags: [OpenAPITags.Org, OpenAPITags.Site],
request: {
params: listSitesParamsSchema,
query: listSitesSchema
@@ -235,7 +245,7 @@ export async function listSites(
.where(
or(
eq(userSites.userId, req.user!.userId),
eq(roleSites.roleId, req.userOrgRoleId!)
inArray(roleSites.roleId, req.userOrgRoleIds!)
)
);
} else {
@@ -245,7 +255,7 @@ export async function listSites(
.where(eq(sites.orgId, orgId));
}
const { pageSize, page, query, sort_by, order, online } =
const { pageSize, page, query, sort_by, order, online, status } =
parsedQuery.data;
const accessibleSiteIds = accessibleSites.map((site) => site.siteId);
@@ -273,6 +283,9 @@ export async function listSites(
if (typeof online !== "undefined") {
conditions.push(eq(sites.online, online));
}
if (typeof status !== "undefined") {
conditions.push(eq(sites.status, status));
}
const baseQuery = querySitesBase().where(and(...conditions));

View File

@@ -19,7 +19,8 @@ const updateSiteBodySchema = z
.strictObject({
name: z.string().min(1).max(255).optional(),
niceId: z.string().min(1).max(255).optional(),
dockerSocketEnabled: z.boolean().optional()
dockerSocketEnabled: z.boolean().optional(),
status: z.enum(["pending", "approved"]).optional(),
// remoteSubnets: z.string().optional()
// subdomain: z
// .string()

View File

@@ -0,0 +1,44 @@
export type SiteProvisioningKeyListItem = {
siteProvisioningKeyId: string;
orgId: string;
lastChars: string;
createdAt: string;
name: string;
lastUsed: string | null;
maxBatchSize: number | null;
numUsed: number;
validUntil: string | null;
approveNewSites: boolean;
};
export type ListSiteProvisioningKeysResponse = {
siteProvisioningKeys: SiteProvisioningKeyListItem[];
pagination: { total: number; limit: number; offset: number };
};
export type CreateSiteProvisioningKeyResponse = {
siteProvisioningKeyId: string;
orgId: string;
name: string;
siteProvisioningKey: string;
lastChars: string;
createdAt: string;
lastUsed: string | null;
maxBatchSize: number | null;
numUsed: number;
validUntil: string | null;
approveNewSites: boolean;
};
export type UpdateSiteProvisioningKeyResponse = {
siteProvisioningKeyId: string;
orgId: string;
name: string;
lastChars: string;
createdAt: string;
lastUsed: string | null;
maxBatchSize: number | null;
numUsed: number;
validUntil: string | null;
approveNewSites: boolean;
};

View File

@@ -90,7 +90,7 @@ const createSiteResourceSchema = z
},
{
message:
"Destination must be a valid IP address or valid domain AND alias is required"
"Destination must be a valid IPV4 address or valid domain AND alias is required"
}
)
.refine(

View File

@@ -1,5 +1,4 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import {
clientSiteResources,
clientSiteResourcesAssociationsCache,
@@ -9,33 +8,32 @@ import {
roles,
roleSiteResources,
siteNetworks,
SiteResource,
siteResources,
sites,
networks,
Transaction,
userSiteResources
} from "@server/db";
import { siteResources, SiteResource } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { eq, and, ne, inArray } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { OpenAPITags, registry } from "@server/openApi";
import { updatePeerData, updateTargets } from "@server/routers/client/targets";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import {
generateAliasConfig,
generateRemoteSubnets,
generateSubnetProxyTargets,
generateSubnetProxyTargetV2,
isIpInCidr,
portRangeStringSchema
} from "@server/lib/ip";
import {
getClientSiteResourceAccess,
rebuildClientAssociationsFromSiteResource
} from "@server/lib/rebuildClientAssociations";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { rebuildClientAssociationsFromSiteResource } from "@server/lib/rebuildClientAssociations";
import logger from "@server/logger";
import HttpCode from "@server/types/HttpCode";
import { NextFunction, Request, Response } from "express";
import createHttpError from "http-errors";
import { z } from "zod";
import { fromError } from "zod-validation-error";
const updateSiteResourceParamsSchema = z.strictObject({
siteResourceId: z.string().transform(Number).pipe(z.int().positive())
@@ -46,6 +44,15 @@ const updateSiteResourceSchema = z
name: z.string().min(1).max(255).optional(),
siteIds: z.array(z.int()),
// niceId: z.string().min(1).max(255).regex(/^[a-zA-Z0-9-]+$/, "niceId can only contain letters, numbers, and dashes").optional(),
niceId: z
.string()
.min(1)
.max(255)
.regex(
/^[a-zA-Z0-9-]+$/,
"niceId can only contain letters, numbers, and dashes"
)
.optional(),
// mode: z.enum(["host", "cidr", "port"]).optional(),
mode: z.enum(["host", "cidr"]).optional(),
// protocol: z.enum(["tcp", "udp"]).nullish(),
@@ -169,6 +176,7 @@ export async function updateSiteResource(
const {
name,
siteIds, // because it can change
niceId,
mode,
destination,
alias,
@@ -344,14 +352,15 @@ export async function updateSiteResource(
[updatedSiteResource] = await trx
.update(siteResources)
.set({
name: name,
mode: mode,
destination: destination,
enabled: enabled,
name,
niceId,
mode,
destination,
enabled,
alias: alias && alias.trim() ? alias : null,
tcpPortRangeString: tcpPortRangeString,
udpPortRangeString: udpPortRangeString,
disableIcmp: disableIcmp,
tcpPortRangeString,
udpPortRangeString,
disableIcmp,
...sshPamSet
})
.where(
@@ -467,7 +476,10 @@ export async function updateSiteResource(
await trx
.delete(siteNetworks)
.where(
eq(siteNetworks.networkId, updatedSiteResource.networkId!)
eq(
siteNetworks.networkId,
updatedSiteResource.networkId!
)
);
for (const siteId of siteIds) {
@@ -546,9 +558,7 @@ export async function updateSiteResource(
);
}
logger.info(
`Updated site resource ${siteResourceId}`
);
logger.info(`Updated site resource ${siteResourceId}`);
await handleMessagingForUpdatedSiteResource(
existingSiteResource,
@@ -635,11 +645,11 @@ export async function handleMessagingForUpdatedSiteResource(
// Only update targets on newt if destination changed
if (destinationChanged || portRangesChanged) {
const oldTargets = generateSubnetProxyTargets(
const oldTarget = generateSubnetProxyTargetV2(
existingSiteResource,
mergedAllClients
);
const newTargets = generateSubnetProxyTargets(
const newTarget = generateSubnetProxyTargetV2(
updatedSiteResource,
mergedAllClients
);
@@ -647,8 +657,8 @@ export async function handleMessagingForUpdatedSiteResource(
await updateTargets(
newt.newtId,
{
oldTargets: oldTargets,
newTargets: newTargets
oldTargets: oldTarget ? [oldTarget] : [],
newTargets: newTarget ? [newTarget] : []
},
newt.version
);
@@ -670,10 +680,7 @@ export async function handleMessagingForUpdatedSiteResource(
)
.innerJoin(
siteNetworks,
eq(
siteNetworks.networkId,
siteResources.networkId
)
eq(siteNetworks.networkId, siteResources.networkId)
)
.where(
and(
@@ -693,7 +700,6 @@ export async function handleMessagingForUpdatedSiteResource(
)
);
const oldDestinationStillInUseByASite =
oldDestinationStillInUseSites.length > 0;

View File

@@ -77,7 +77,8 @@ export const handleHealthcheckStatusMessage: MessageHandler = async (
const [targetCheck] = await db
.select({
targetId: targets.targetId,
siteId: targets.siteId
siteId: targets.siteId,
hcStatus: targetHealthCheck.hcHealth
})
.from(targets)
.innerJoin(
@@ -85,6 +86,7 @@ export const handleHealthcheckStatusMessage: MessageHandler = async (
eq(targets.resourceId, resources.resourceId)
)
.innerJoin(sites, eq(targets.siteId, sites.siteId))
.innerJoin(targetHealthCheck, eq(targets.targetId, targetHealthCheck.targetId))
.where(
and(
eq(targets.targetId, targetIdNum),
@@ -101,6 +103,14 @@ export const handleHealthcheckStatusMessage: MessageHandler = async (
continue;
}
// check if the status has changed
if (targetCheck.hcStatus === healthStatus.status) {
logger.debug(
`Health status for target ${targetId} is already ${healthStatus.status}, skipping update`
);
continue;
}
// Update the target's health status in the database
await db
.update(targetHealthCheck)

View File

@@ -40,6 +40,7 @@ function queryTargets(resourceId: number) {
resourceId: targets.resourceId,
siteId: targets.siteId,
siteType: sites.type,
siteName: sites.name,
hcEnabled: targetHealthCheck.hcEnabled,
hcPath: targetHealthCheck.hcPath,
hcScheme: targetHealthCheck.hcScheme,

View File

@@ -188,6 +188,8 @@ export async function updateTarget(
);
}
const pathMatchTypeRemoved = parsedBody.data.pathMatchType === null;
const [updatedTarget] = await db
.update(targets)
.set({
@@ -200,8 +202,8 @@ export async function updateTarget(
path: parsedBody.data.path,
pathMatchType: parsedBody.data.pathMatchType,
priority: parsedBody.data.priority,
rewritePath: parsedBody.data.rewritePath,
rewritePathType: parsedBody.data.rewritePathType
rewritePath: pathMatchTypeRemoved ? null : parsedBody.data.rewritePath,
rewritePathType: pathMatchTypeRemoved ? null : parsedBody.data.rewritePathType
})
.where(eq(targets.targetId, targetId))
.returning();

View File

@@ -30,12 +30,15 @@ export async function traefikConfigProvider(
traefikConfig.http.middlewares[badgerMiddlewareName] = {
plugin: {
[badgerMiddlewareName]: {
apiBaseUrl: new URL(
"/api/v1",
`http://${
config.getRawConfig().server.internal_hostname
}:${config.getRawConfig().server.internal_port}`
).href,
apiBaseUrl:
config.getRawConfig().server.badger_override ||
new URL(
"/api/v1",
`http://${
config.getRawConfig().server
.internal_hostname
}:${config.getRawConfig().server.internal_port}`
).href,
userSessionCookieName:
config.getRawConfig().server.session_cookie_name,
@@ -61,7 +64,7 @@ export async function traefikConfigProvider(
return res.status(HttpCode.OK).json(traefikConfig);
} catch (e) {
logger.error(`Failed to build Traefik config: ${e}`);
logger.error(e);
return res.status(HttpCode.INTERNAL_SERVER_ERROR).json({
error: "Failed to build Traefik config"
});

View File

@@ -1,8 +1,8 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, orgs, UserOrg } from "@server/db";
import { roles, userInvites, userOrgs, users } from "@server/db";
import { eq, and, inArray, ne } from "drizzle-orm";
import { db, orgs } from "@server/db";
import { roles, userInviteRoles, userInvites, userOrgs, users } from "@server/db";
import { eq, and, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -141,17 +141,34 @@ export async function acceptInvite(
);
}
let roleId: number;
// get the role to make sure it exists
const existingRole = await db
const inviteRoleRows = await db
.select({ roleId: userInviteRoles.roleId })
.from(userInviteRoles)
.where(eq(userInviteRoles.inviteId, inviteId));
const inviteRoleIds = [
...new Set(inviteRoleRows.map((r) => r.roleId))
];
if (inviteRoleIds.length === 0) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"This invitation has no roles. Please contact an admin."
)
);
}
const existingRoles = await db
.select()
.from(roles)
.where(eq(roles.roleId, existingInvite.roleId))
.limit(1);
if (existingRole.length) {
roleId = existingRole[0].roleId;
} else {
// TODO: use the default role on the org instead of failing
.where(
and(
eq(roles.orgId, existingInvite.orgId),
inArray(roles.roleId, inviteRoleIds)
)
);
if (existingRoles.length !== inviteRoleIds.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
@@ -165,9 +182,9 @@ export async function acceptInvite(
org,
{
userId: existingUser[0].userId,
orgId: existingInvite.orgId,
roleId: existingInvite.roleId
orgId: existingInvite.orgId
},
inviteRoleIds,
trx
);

View File

@@ -1,42 +1,44 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { clients, db, UserOrg } from "@server/db";
import { userOrgs, roles } from "@server/db";
import stoi from "@server/lib/stoi";
import { clients, db } from "@server/db";
import { userOrgRoles, userOrgs, roles } from "@server/db";
import { eq, and } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import stoi from "@server/lib/stoi";
import { OpenAPITags, registry } from "@server/openApi";
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
const addUserRoleParamsSchema = z.strictObject({
userId: z.string(),
roleId: z.string().transform(stoi).pipe(z.number())
/** Legacy path param order: /role/:roleId/add/:userId */
const addUserRoleLegacyParamsSchema = z.strictObject({
roleId: z.string().transform(stoi).pipe(z.number()),
userId: z.string()
});
export type AddUserRoleResponse = z.infer<typeof addUserRoleParamsSchema>;
registry.registerPath({
method: "post",
path: "/role/{roleId}/add/{userId}",
description: "Add a role to a user.",
description:
"Legacy: set exactly one role for the user (replaces any other roles the user has in the org).",
tags: [OpenAPITags.Role, OpenAPITags.User],
request: {
params: addUserRoleParamsSchema
params: addUserRoleLegacyParamsSchema
},
responses: {}
});
export async function addUserRole(
export async function addUserRoleLegacy(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const parsedParams = addUserRoleParamsSchema.safeParse(req.params);
const parsedParams = addUserRoleLegacyParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
@@ -57,7 +59,6 @@ export async function addUserRole(
);
}
// get the role
const [role] = await db
.select()
.from(roles)
@@ -70,7 +71,7 @@ export async function addUserRole(
);
}
const existingUser = await db
const [existingUser] = await db
.select()
.from(userOrgs)
.where(
@@ -78,7 +79,7 @@ export async function addUserRole(
)
.limit(1);
if (existingUser.length === 0) {
if (!existingUser) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
@@ -87,7 +88,7 @@ export async function addUserRole(
);
}
if (existingUser[0].isOwner) {
if (existingUser.isOwner) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
@@ -96,13 +97,13 @@ export async function addUserRole(
);
}
const roleExists = await db
const [roleInOrg] = await db
.select()
.from(roles)
.where(and(eq(roles.roleId, roleId), eq(roles.orgId, role.orgId)))
.limit(1);
if (roleExists.length === 0) {
if (!roleInOrg) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
@@ -111,20 +112,22 @@ export async function addUserRole(
);
}
let newUserRole: UserOrg | null = null;
await db.transaction(async (trx) => {
[newUserRole] = await trx
.update(userOrgs)
.set({ roleId })
await trx
.delete(userOrgRoles)
.where(
and(
eq(userOrgs.userId, userId),
eq(userOrgs.orgId, role.orgId)
eq(userOrgRoles.userId, userId),
eq(userOrgRoles.orgId, role.orgId)
)
)
.returning();
);
await trx.insert(userOrgRoles).values({
userId,
orgId: role.orgId,
roleId
});
// get the client associated with this user in this org
const orgClients = await trx
.select()
.from(clients)
@@ -133,17 +136,15 @@ export async function addUserRole(
eq(clients.userId, userId),
eq(clients.orgId, role.orgId)
)
)
.limit(1);
);
for (const orgClient of orgClients) {
// we just changed the user's role, so we need to rebuild client associations and what they have access to
await rebuildClientAssociationsFromClient(orgClient, trx);
}
});
return response(res, {
data: newUserRole,
data: { ...existingUser, roleId },
success: true,
error: false,
message: "Role added to user successfully",

View File

@@ -6,8 +6,8 @@ import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { db, orgs, UserOrg } from "@server/db";
import { and, eq, inArray, ne } from "drizzle-orm";
import { db, orgs } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import { idp, idpOidcConfig, roles, userOrgs, users } from "@server/db";
import { generateId } from "@server/auth/sessions/app";
import { usageService } from "@server/lib/billing/usageService";
@@ -15,21 +15,43 @@ import { FeatureId } from "@server/lib/billing";
import { build } from "@server/build";
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
import { isSubscribed } from "#dynamic/lib/isSubscribed";
import { tierMatrix } from "@server/lib/billing/tierMatrix";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { assignUserToOrg } from "@server/lib/userOrg";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
});
const bodySchema = z.strictObject({
email: z.string().email().toLowerCase().optional(),
username: z.string().nonempty().toLowerCase(),
name: z.string().optional(),
type: z.enum(["internal", "oidc"]).optional(),
idpId: z.number().optional(),
roleId: z.number()
});
const bodySchema = z
.strictObject({
email: z.string().email().toLowerCase().optional(),
username: z.string().nonempty().toLowerCase(),
name: z.string().optional(),
type: z.enum(["internal", "oidc"]).optional(),
idpId: z.number().optional(),
roleIds: z.array(z.number().int().positive()).min(1).optional(),
roleId: z.number().int().positive().optional()
})
.refine(
(d) =>
(d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
{ message: "roleIds or roleId is required", path: ["roleIds"] }
)
.transform((data) => ({
email: data.email,
username: data.username,
name: data.name,
type: data.type,
idpId: data.idpId,
roleIds: [
...new Set(
data.roleIds && data.roleIds.length > 0
? data.roleIds
: [data.roleId!]
)
]
}));
export type CreateOrgUserResponse = {};
@@ -78,7 +100,8 @@ export async function createOrgUser(
}
const { orgId } = parsedParams.data;
const { username, email, name, type, idpId, roleId } = parsedBody.data;
const { username, email, name, type, idpId, roleIds: uniqueRoleIds } =
parsedBody.data;
if (build == "saas") {
const usage = await usageService.getUsage(orgId, FeatureId.USERS);
@@ -109,17 +132,6 @@ export async function createOrgUser(
}
}
const [role] = await db
.select()
.from(roles)
.where(eq(roles.roleId, roleId));
if (!role) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Role ID not found")
);
}
if (type === "internal") {
return next(
createHttpError(
@@ -152,6 +164,38 @@ export async function createOrgUser(
);
}
const supportsMultiRole = await isLicensedOrSubscribed(
orgId,
tierMatrix[TierFeature.FullRbac]
);
if (!supportsMultiRole && uniqueRoleIds.length > 1) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Multiple roles per user require a subscription or license that includes full RBAC."
)
);
}
const orgRoles = await db
.select({ roleId: roles.roleId })
.from(roles)
.where(
and(
eq(roles.orgId, orgId),
inArray(roles.roleId, uniqueRoleIds)
)
);
if (orgRoles.length !== uniqueRoleIds.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Invalid role ID or role does not belong to this organization"
)
);
}
const [org] = await db
.select()
.from(orgs)
@@ -221,12 +265,16 @@ export async function createOrgUser(
);
}
await assignUserToOrg(org, {
orgId,
userId: existingUser.userId,
roleId: role.roleId,
autoProvisioned: false
}, trx);
await assignUserToOrg(
org,
{
orgId,
userId: existingUser.userId,
autoProvisioned: false,
},
uniqueRoleIds,
trx
);
} else {
userId = generateId(15);
@@ -244,12 +292,16 @@ export async function createOrgUser(
})
.returning();
await assignUserToOrg(org, {
orgId,
userId: newUser.userId,
roleId: role.roleId,
autoProvisioned: false
}, trx);
await assignUserToOrg(
org,
{
orgId,
userId: newUser.userId,
autoProvisioned: false,
},
uniqueRoleIds,
trx
);
}
await calculateUserClientsForOrgs(userId, trx);

View File

@@ -1,7 +1,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, idp, idpOidcConfig } from "@server/db";
import { roles, userOrgs, users } from "@server/db";
import { roles, userOrgRoles, userOrgs, users } from "@server/db";
import { and, eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -12,7 +12,7 @@ import { ActionsEnum, checkUserActionPermission } from "@server/auth/actions";
import { OpenAPITags, registry } from "@server/openApi";
export async function queryUser(orgId: string, userId: string) {
const [user] = await db
const [userRow] = await db
.select({
orgId: userOrgs.orgId,
userId: users.userId,
@@ -20,10 +20,7 @@ export async function queryUser(orgId: string, userId: string) {
username: users.username,
name: users.name,
type: users.type,
roleId: userOrgs.roleId,
roleName: roles.name,
isOwner: userOrgs.isOwner,
isAdmin: roles.isAdmin,
twoFactorEnabled: users.twoFactorEnabled,
autoProvisioned: userOrgs.autoProvisioned,
idpId: users.idpId,
@@ -33,13 +30,40 @@ export async function queryUser(orgId: string, userId: string) {
idpAutoProvision: idp.autoProvision
})
.from(userOrgs)
.leftJoin(roles, eq(userOrgs.roleId, roles.roleId))
.leftJoin(users, eq(userOrgs.userId, users.userId))
.leftJoin(idp, eq(users.idpId, idp.idpId))
.leftJoin(idpOidcConfig, eq(idp.idpId, idpOidcConfig.idpId))
.where(and(eq(userOrgs.userId, userId), eq(userOrgs.orgId, orgId)))
.limit(1);
return user;
if (!userRow) return undefined;
const roleRows = await db
.select({
roleId: userOrgRoles.roleId,
roleName: roles.name,
isAdmin: roles.isAdmin
})
.from(userOrgRoles)
.leftJoin(roles, eq(userOrgRoles.roleId, roles.roleId))
.where(
and(
eq(userOrgRoles.userId, userId),
eq(userOrgRoles.orgId, orgId)
)
);
const isAdmin = roleRows.some((r) => r.isAdmin);
return {
...userRow,
isAdmin,
roleIds: roleRows.map((r) => r.roleId),
roles: roleRows.map((r) => ({
roleId: r.roleId,
name: r.roleName ?? ""
}))
};
}
export type GetOrgUserResponse = NonNullable<

View File

@@ -20,7 +20,8 @@ async function queryUser(userId: string) {
emailVerified: users.emailVerified,
serverAdmin: users.serverAdmin,
idpName: idp.name,
idpId: users.idpId
idpId: users.idpId,
locale: users.locale
})
.from(users)
.leftJoin(idp, eq(users.idpId, idp.idpId))

View File

@@ -1,7 +1,8 @@
export * from "./getUser";
export * from "./removeUserOrg";
export * from "./listUsers";
export * from "./addUserRole";
export * from "./types";
export * from "./addUserRoleLegacy";
export * from "./inviteUser";
export * from "./acceptInvite";
export * from "./getOrgUser";
@@ -16,4 +17,5 @@ export * from "./createOrgUser";
export * from "./adminUpdateUser2FA";
export * from "./adminGetUser";
export * from "./updateOrgUser";
export * from "./updateUserLocale";
export * from "./myDevice";

View File

@@ -1,8 +1,8 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { orgs, roles, userInvites, userOrgs, users } from "@server/db";
import { and, eq } from "drizzle-orm";
import { orgs, roles, userInviteRoles, userInvites, userOrgs, users } from "@server/db";
import { and, eq, inArray } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
@@ -18,22 +18,44 @@ import { OpenAPITags, registry } from "@server/openApi";
import { UserType } from "@server/types/UserTypes";
import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
import { TierFeature, tierMatrix } from "@server/lib/billing/tierMatrix";
import { build } from "@server/build";
import cache from "#dynamic/lib/cache";
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
const inviteUserParamsSchema = z.strictObject({
orgId: z.string()
});
const inviteUserBodySchema = z.strictObject({
email: z.email().toLowerCase(),
roleId: z.number(),
validHours: z.number().gt(0).lte(168),
sendEmail: z.boolean().optional(),
regenerate: z.boolean().optional()
});
const inviteUserBodySchema = z
.strictObject({
email: z.email().toLowerCase(),
roleIds: z.array(z.number().int().positive()).min(1).optional(),
roleId: z.number().int().positive().optional(),
validHours: z.number().gt(0).lte(168),
sendEmail: z.boolean().optional(),
regenerate: z.boolean().optional()
})
.refine(
(d) =>
(d.roleIds != null && d.roleIds.length > 0) || d.roleId != null,
{ message: "roleIds or roleId is required", path: ["roleIds"] }
)
.transform((data) => ({
email: data.email,
validHours: data.validHours,
sendEmail: data.sendEmail,
regenerate: data.regenerate,
roleIds: [
...new Set(
data.roleIds && data.roleIds.length > 0
? data.roleIds
: [data.roleId!]
)
]
}));
export type InviteUserBody = z.infer<typeof inviteUserBodySchema>;
export type InviteUserBody = z.input<typeof inviteUserBodySchema>;
export type InviteUserResponse = {
inviteLink: string;
@@ -88,7 +110,7 @@ export async function inviteUser(
const {
email,
validHours,
roleId,
roleIds: uniqueRoleIds,
sendEmail: doEmail,
regenerate
} = parsedBody.data;
@@ -105,14 +127,30 @@ export async function inviteUser(
);
}
// Validate that the roleId belongs to the target organization
const [role] = await db
.select()
.from(roles)
.where(and(eq(roles.roleId, roleId), eq(roles.orgId, orgId)))
.limit(1);
const supportsMultiRole = await isLicensedOrSubscribed(
orgId,
tierMatrix[TierFeature.FullRbac]
);
if (!supportsMultiRole && uniqueRoleIds.length > 1) {
return next(
createHttpError(
HttpCode.FORBIDDEN,
"Multiple roles per user require a subscription or license that includes full RBAC."
)
);
}
if (!role) {
const orgRoles = await db
.select({ roleId: roles.roleId })
.from(roles)
.where(
and(
eq(roles.orgId, orgId),
inArray(roles.roleId, uniqueRoleIds)
)
);
if (orgRoles.length !== uniqueRoleIds.length) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
@@ -191,7 +229,8 @@ export async function inviteUser(
}
if (existingInvite.length) {
const attempts = (await cache.get<number>(email)) || 0;
const attempts =
(await cache.get<number>("regenerateInvite:" + email)) || 0;
if (attempts >= 3) {
return next(
createHttpError(
@@ -201,7 +240,7 @@ export async function inviteUser(
);
}
await cache.set(email, attempts + 1);
await cache.set("regenerateInvite:" + email, attempts + 1, 3600);
const inviteId = existingInvite[0].inviteId; // Retrieve the original inviteId
const token = generateRandomString(
@@ -273,9 +312,11 @@ export async function inviteUser(
orgId,
email,
expiresAt,
tokenHash,
roleId
tokenHash
});
await trx.insert(userInviteRoles).values(
uniqueRoleIds.map((roleId) => ({ inviteId, roleId }))
);
});
const inviteLink = `${config.getRawConfig().app.dashboard_url}/invite?token=${inviteId}-${token}&email=${encodeURIComponent(email)}`;

View File

@@ -1,11 +1,11 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { userInvites, roles } from "@server/db";
import { userInvites, userInviteRoles, roles } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { sql } from "drizzle-orm";
import { sql, eq, and, inArray } from "drizzle-orm";
import logger from "@server/logger";
import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
@@ -29,24 +29,66 @@ const listInvitationsQuerySchema = z.strictObject({
.pipe(z.int().nonnegative())
});
async function queryInvitations(orgId: string, limit: number, offset: number) {
return await db
export type InvitationListRow = {
inviteId: string;
email: string;
expiresAt: number;
roles: { roleId: number; roleName: string | null }[];
};
async function queryInvitations(
orgId: string,
limit: number,
offset: number
): Promise<InvitationListRow[]> {
const inviteRows = await db
.select({
inviteId: userInvites.inviteId,
email: userInvites.email,
expiresAt: userInvites.expiresAt,
roleId: userInvites.roleId,
roleName: roles.name
expiresAt: userInvites.expiresAt
})
.from(userInvites)
.leftJoin(roles, sql`${userInvites.roleId} = ${roles.roleId}`)
.where(sql`${userInvites.orgId} = ${orgId}`)
.where(eq(userInvites.orgId, orgId))
.limit(limit)
.offset(offset);
if (inviteRows.length === 0) {
return [];
}
const inviteIds = inviteRows.map((r) => r.inviteId);
const roleRows = await db
.select({
inviteId: userInviteRoles.inviteId,
roleId: userInviteRoles.roleId,
roleName: roles.name
})
.from(userInviteRoles)
.innerJoin(roles, eq(userInviteRoles.roleId, roles.roleId))
.where(
and(eq(roles.orgId, orgId), inArray(userInviteRoles.inviteId, inviteIds))
);
const rolesByInvite = new Map<
string,
{ roleId: number; roleName: string | null }[]
>();
for (const row of roleRows) {
const list = rolesByInvite.get(row.inviteId) ?? [];
list.push({ roleId: row.roleId, roleName: row.roleName });
rolesByInvite.set(row.inviteId, list);
}
return inviteRows.map((inv) => ({
inviteId: inv.inviteId,
email: inv.email,
expiresAt: inv.expiresAt,
roles: rolesByInvite.get(inv.inviteId) ?? []
}));
}
export type ListInvitationsResponse = {
invitations: NonNullable<Awaited<ReturnType<typeof queryInvitations>>>;
invitations: InvitationListRow[];
pagination: { total: number; limit: number; offset: number };
};
@@ -95,7 +137,7 @@ export async function listInvitations(
const [{ count }] = await db
.select({ count: sql<number>`count(*)` })
.from(userInvites)
.where(sql`${userInvites.orgId} = ${orgId}`);
.where(eq(userInvites.orgId, orgId));
return response<ListInvitationsResponse>(res, {
data: {

View File

@@ -1,15 +1,14 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, idpOidcConfig } from "@server/db";
import { idp, roles, userOrgs, users } from "@server/db";
import { idp, roles, userOrgRoles, userOrgs, users } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import { and, sql } from "drizzle-orm";
import { and, eq, inArray, sql } from "drizzle-orm";
import logger from "@server/logger";
import { fromZodError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { eq } from "drizzle-orm";
const listUsersParamsSchema = z.strictObject({
orgId: z.string()
@@ -31,7 +30,7 @@ const listUsersSchema = z.strictObject({
});
async function queryUsers(orgId: string, limit: number, offset: number) {
return await db
const rows = await db
.select({
id: users.userId,
email: users.email,
@@ -41,8 +40,6 @@ async function queryUsers(orgId: string, limit: number, offset: number) {
username: users.username,
name: users.name,
type: users.type,
roleId: userOrgs.roleId,
roleName: roles.name,
isOwner: userOrgs.isOwner,
idpName: idp.name,
idpId: users.idpId,
@@ -52,12 +49,48 @@ async function queryUsers(orgId: string, limit: number, offset: number) {
})
.from(users)
.leftJoin(userOrgs, eq(users.userId, userOrgs.userId))
.leftJoin(roles, eq(userOrgs.roleId, roles.roleId))
.leftJoin(idp, eq(users.idpId, idp.idpId))
.leftJoin(idpOidcConfig, eq(idpOidcConfig.idpId, idp.idpId))
.where(eq(userOrgs.orgId, orgId))
.limit(limit)
.offset(offset);
const userIds = rows.map((r) => r.id);
const roleRows =
userIds.length === 0
? []
: await db
.select({
userId: userOrgRoles.userId,
roleId: userOrgRoles.roleId,
roleName: roles.name
})
.from(userOrgRoles)
.leftJoin(roles, eq(userOrgRoles.roleId, roles.roleId))
.where(
and(
eq(userOrgRoles.orgId, orgId),
inArray(userOrgRoles.userId, userIds)
)
);
const rolesByUser = new Map<
string,
{ roleId: number; roleName: string }[]
>();
for (const r of roleRows) {
const list = rolesByUser.get(r.userId) ?? [];
list.push({ roleId: r.roleId, roleName: r.roleName ?? "" });
rolesByUser.set(r.userId, list);
}
return rows.map((row) => {
const userRoles = rolesByUser.get(row.id) ?? [];
return {
...row,
roles: userRoles
};
});
}
export type ListUsersResponse = {

View File

@@ -1,5 +1,5 @@
import { Request, Response, NextFunction } from "express";
import { db, Olm, olms, orgs, userOrgs } from "@server/db";
import { db, Olm, olms, orgs, userOrgRoles, userOrgs } from "@server/db";
import { idp, users } from "@server/db";
import { and, eq } from "drizzle-orm";
import response from "@server/lib/response";
@@ -63,7 +63,8 @@ export async function myDevice(
emailVerified: users.emailVerified,
serverAdmin: users.serverAdmin,
idpName: idp.name,
idpId: users.idpId
idpId: users.idpId,
locale: users.locale
})
.from(users)
.leftJoin(idp, eq(users.idpId, idp.idpId))
@@ -84,16 +85,31 @@ export async function myDevice(
.from(olms)
.where(and(eq(olms.userId, userId), eq(olms.olmId, olmId)));
const userOrganizations = await db
const userOrgRows = await db
.select({
orgId: userOrgs.orgId,
orgName: orgs.name,
roleId: userOrgs.roleId
orgName: orgs.name
})
.from(userOrgs)
.where(eq(userOrgs.userId, userId))
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId));
const roleRows = await db
.select({
orgId: userOrgRoles.orgId,
roleId: userOrgRoles.roleId
})
.from(userOrgRoles)
.where(eq(userOrgRoles.userId, userId));
const roleByOrg = new Map(
roleRows.map((r) => [r.orgId, r.roleId])
);
const userOrganizations = userOrgRows.map((row) => ({
...row,
roleId: roleByOrg.get(row.orgId) ?? 0
}));
return response<MyDeviceResponse>(res, {
data: {
user,

View File

@@ -0,0 +1,18 @@
import type { UserOrg } from "@server/db";
export type AddUserRoleResponse = {
userId: string;
roleId: number;
};
/** Legacy POST /role/:roleId/add/:userId response shape (membership + effective role). */
export type AddUserRoleLegacyResponse = UserOrg & { roleId: number };
export type SetUserOrgRolesParams = {
orgId: string;
userId: string;
};
export type SetUserOrgRolesBody = {
roleIds: number[];
};

View File

@@ -0,0 +1,57 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db } from "@server/db";
import { users } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
import createHttpError from "http-errors";
import logger from "@server/logger";
import { fromError } from "zod-validation-error";
const bodySchema = z.strictObject({
locale: z.string().min(2).max(10)
});
export async function updateUserLocale(
req: Request,
res: Response,
next: NextFunction
): Promise<any> {
try {
const userId = req.user?.userId;
if (!userId) {
return next(
createHttpError(HttpCode.UNAUTHORIZED, "User not found")
);
}
const parsedBody = bodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { locale } = parsedBody.data;
await db.update(users).set({ locale }).where(eq(users.userId, userId));
return response(res, {
data: null,
success: true,
error: false,
message: "User locale updated successfully",
status: HttpCode.OK
});
} catch (error) {
logger.error(error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}

View File

@@ -11,6 +11,7 @@ import {
startNewtOfflineChecker,
handleNewtDisconnectingMessage
} from "../newt";
import { startPingAccumulator } from "../newt/pingAccumulator";
import {
handleOlmRegisterMessage,
handleOlmRelayMessage,
@@ -46,6 +47,10 @@ export const messageHandlers: Record<string, MessageHandler> = {
"ws/round-trip/complete": handleRoundTripMessage
};
// Start the ping accumulator for all builds — it batches per-site online/lastPing
// updates into periodic bulk writes, preventing connection pool exhaustion.
startPingAccumulator();
if (build != "saas") {
startOlmOfflineChecker(); // this is to handle the offline check for olms
startNewtOfflineChecker(); // this is to handle the offline check for newts

View File

@@ -6,6 +6,7 @@ import { Socket } from "net";
import { Newt, newts, NewtSession, olms, Olm, OlmSession, sites } from "@server/db";
import { eq } from "drizzle-orm";
import { db } from "@server/db";
import { recordPing } from "@server/routers/newt/pingAccumulator";
import { validateNewtSessionToken } from "@server/auth/sessions/newt";
import { validateOlmSessionToken } from "@server/auth/sessions/olm";
import { messageHandlers } from "./messageHandlers";
@@ -386,22 +387,14 @@ const setupConnection = async (
// the same as modern newt clients.
if (clientType === "newt") {
const newtClient = client as Newt;
ws.on("ping", async () => {
ws.on("ping", () => {
if (!newtClient.siteId) return;
try {
await db
.update(sites)
.set({
online: true,
lastPing: Math.floor(Date.now() / 1000)
})
.where(eq(sites.siteId, newtClient.siteId));
} catch (error) {
logger.error(
"Error updating newt site online state on WS ping",
{ error }
);
}
// Record the ping in the accumulator instead of writing to the
// database on every WS ping frame. The accumulator flushes all
// pending pings in a single batched UPDATE every ~10s, which
// prevents connection pool exhaustion under load (especially
// with cross-region latency to the database).
recordPing(newtClient.siteId);
});
}