mirror of
https://github.com/fosrl/pangolin.git
synced 2026-03-31 15:06:42 +00:00
Merge branch 'dev' into feature/region-rules
This commit is contained in:
@@ -43,7 +43,7 @@ registry.registerPath({
|
||||
method: "post",
|
||||
path: "/resource/{resourceId}/access-token",
|
||||
description: "Generate a new access token for a resource.",
|
||||
tags: [OpenAPITags.Resource, OpenAPITags.AccessToken],
|
||||
tags: [OpenAPITags.PublicResource, OpenAPITags.AccessToken],
|
||||
request: {
|
||||
params: generateAccssTokenParamsSchema,
|
||||
body: {
|
||||
|
||||
@@ -122,7 +122,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/access-tokens",
|
||||
description: "List all access tokens in an organization.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.AccessToken],
|
||||
tags: [OpenAPITags.AccessToken],
|
||||
request: {
|
||||
params: z.object({
|
||||
orgId: z.string()
|
||||
@@ -135,8 +135,8 @@ registry.registerPath({
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/resource/{resourceId}/access-tokens",
|
||||
description: "List all access tokens in an organization.",
|
||||
tags: [OpenAPITags.Resource, OpenAPITags.AccessToken],
|
||||
description: "List all access tokens for a resource.",
|
||||
tags: [OpenAPITags.PublicResource, OpenAPITags.AccessToken],
|
||||
request: {
|
||||
params: z.object({
|
||||
resourceId: z.number()
|
||||
@@ -208,7 +208,7 @@ export async function listAccessTokens(
|
||||
.where(
|
||||
or(
|
||||
eq(userResources.userId, req.user!.userId),
|
||||
eq(roleResources.roleId, req.userOrgRoleId!)
|
||||
inArray(roleResources.roleId, req.userOrgRoleIds!)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
|
||||
@@ -37,7 +37,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/api-key",
|
||||
description: "Create a new API key scoped to the organization.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
|
||||
tags: [OpenAPITags.ApiKey],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
|
||||
@@ -18,7 +18,7 @@ registry.registerPath({
|
||||
method: "delete",
|
||||
path: "/org/{orgId}/api-key/{apiKeyId}",
|
||||
description: "Delete an API key.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
|
||||
tags: [OpenAPITags.ApiKey],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
|
||||
@@ -48,7 +48,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/api-key/{apiKeyId}/actions",
|
||||
description: "List all actions set for an API key.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
|
||||
tags: [OpenAPITags.ApiKey],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
query: querySchema
|
||||
|
||||
@@ -52,7 +52,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/api-keys",
|
||||
description: "List all API keys for an organization",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
|
||||
tags: [OpenAPITags.ApiKey],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
query: querySchema
|
||||
|
||||
@@ -25,7 +25,7 @@ registry.registerPath({
|
||||
path: "/org/{orgId}/api-key/{apiKeyId}/actions",
|
||||
description:
|
||||
"Set actions for an API key. This will replace any existing actions.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.ApiKey],
|
||||
tags: [OpenAPITags.ApiKey],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
|
||||
@@ -20,7 +20,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/request",
|
||||
description: "Query the request audit log for an organization",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryAccessAuditLogsQuery.omit({
|
||||
limit: true,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { db, requestAuditLog, driver } from "@server/db";
|
||||
import { logsDb, requestAuditLog, driver, primaryLogsDb } from "@server/db";
|
||||
import { registry } from "@server/openApi";
|
||||
import { NextFunction } from "express";
|
||||
import { Request, Response } from "express";
|
||||
@@ -35,7 +35,7 @@ const queryAccessAuditLogsQuery = z.object({
|
||||
})
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
|
||||
.optional()
|
||||
.prefault(new Date().toISOString())
|
||||
.prefault(() => new Date().toISOString())
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
@@ -74,12 +74,12 @@ async function query(query: Q) {
|
||||
);
|
||||
}
|
||||
|
||||
const [all] = await db
|
||||
const [all] = await primaryLogsDb
|
||||
.select({ total: count() })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions);
|
||||
|
||||
const [blocked] = await db
|
||||
const [blocked] = await primaryLogsDb
|
||||
.select({ total: count() })
|
||||
.from(requestAuditLog)
|
||||
.where(and(baseConditions, eq(requestAuditLog.action, false)));
|
||||
@@ -88,7 +88,9 @@ async function query(query: Q) {
|
||||
.mapWith(Number)
|
||||
.as("total");
|
||||
|
||||
const requestsPerCountry = await db
|
||||
const DISTINCT_LIMIT = 500;
|
||||
|
||||
const requestsPerCountry = await primaryLogsDb
|
||||
.selectDistinct({
|
||||
code: requestAuditLog.location,
|
||||
count: totalQ
|
||||
@@ -96,7 +98,17 @@ async function query(query: Q) {
|
||||
.from(requestAuditLog)
|
||||
.where(and(baseConditions, not(isNull(requestAuditLog.location))))
|
||||
.groupBy(requestAuditLog.location)
|
||||
.orderBy(desc(totalQ));
|
||||
.orderBy(desc(totalQ))
|
||||
.limit(DISTINCT_LIMIT + 1);
|
||||
|
||||
if (requestsPerCountry.length > DISTINCT_LIMIT) {
|
||||
// throw an error
|
||||
throw createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
// todo: is this even possible?
|
||||
`Too many distinct countries. Please narrow your query.`
|
||||
);
|
||||
}
|
||||
|
||||
const groupByDayFunction =
|
||||
driver === "pg"
|
||||
@@ -106,7 +118,7 @@ async function query(query: Q) {
|
||||
const booleanTrue = driver === "pg" ? sql`true` : sql`1`;
|
||||
const booleanFalse = driver === "pg" ? sql`false` : sql`0`;
|
||||
|
||||
const requestsPerDay = await db
|
||||
const requestsPerDay = await primaryLogsDb
|
||||
.select({
|
||||
day: groupByDayFunction.as("day"),
|
||||
allowedCount:
|
||||
@@ -139,7 +151,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/analytics",
|
||||
description: "Query the request audit analytics for an organization",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryAccessAuditLogsQuery,
|
||||
params: queryRequestAuditLogsParams
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { db, requestAuditLog, resources } from "@server/db";
|
||||
import { logsDb, primaryLogsDb, requestAuditLog, resources, db, primaryDb } from "@server/db";
|
||||
import { registry } from "@server/openApi";
|
||||
import { NextFunction } from "express";
|
||||
import { Request, Response } from "express";
|
||||
import { eq, gt, lt, and, count, desc } from "drizzle-orm";
|
||||
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
|
||||
import { OpenAPITags } from "@server/openApi";
|
||||
import { z } from "zod";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -35,7 +35,7 @@ export const queryAccessAuditLogsQuery = z.object({
|
||||
})
|
||||
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
|
||||
.optional()
|
||||
.prefault(new Date().toISOString())
|
||||
.prefault(() => new Date().toISOString())
|
||||
.openapi({
|
||||
type: "string",
|
||||
format: "date-time",
|
||||
@@ -107,7 +107,7 @@ function getWhere(data: Q) {
|
||||
}
|
||||
|
||||
export function queryRequest(data: Q) {
|
||||
return db
|
||||
return primaryLogsDb
|
||||
.select({
|
||||
id: requestAuditLog.id,
|
||||
timestamp: requestAuditLog.timestamp,
|
||||
@@ -129,21 +129,49 @@ export function queryRequest(data: Q) {
|
||||
host: requestAuditLog.host,
|
||||
path: requestAuditLog.path,
|
||||
method: requestAuditLog.method,
|
||||
tls: requestAuditLog.tls,
|
||||
resourceName: resources.name,
|
||||
resourceNiceId: resources.niceId
|
||||
tls: requestAuditLog.tls
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.leftJoin(
|
||||
resources,
|
||||
eq(requestAuditLog.resourceId, resources.resourceId)
|
||||
) // TODO: Is this efficient?
|
||||
.where(getWhere(data))
|
||||
.orderBy(desc(requestAuditLog.timestamp));
|
||||
}
|
||||
|
||||
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryRequest>>) {
|
||||
// If logs database is the same as main database, we can do a join
|
||||
// Otherwise, we need to fetch resource details separately
|
||||
const resourceIds = logs
|
||||
.map(log => log.resourceId)
|
||||
.filter((id): id is number => id !== null && id !== undefined);
|
||||
|
||||
if (resourceIds.length === 0) {
|
||||
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
|
||||
}
|
||||
|
||||
// Fetch resource details from main database
|
||||
const resourceDetails = await primaryDb
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name,
|
||||
niceId: resources.niceId
|
||||
})
|
||||
.from(resources)
|
||||
.where(inArray(resources.resourceId, resourceIds));
|
||||
|
||||
// Create a map for quick lookup
|
||||
const resourceMap = new Map(
|
||||
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
|
||||
);
|
||||
|
||||
// Enrich logs with resource details
|
||||
return logs.map(log => ({
|
||||
...log,
|
||||
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
|
||||
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
|
||||
}));
|
||||
}
|
||||
|
||||
export function countRequestQuery(data: Q) {
|
||||
const countQuery = db
|
||||
const countQuery = primaryLogsDb
|
||||
.select({ count: count() })
|
||||
.from(requestAuditLog)
|
||||
.where(getWhere(data));
|
||||
@@ -154,7 +182,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/request",
|
||||
description: "Query the request audit log for an organization",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryAccessAuditLogsQuery,
|
||||
params: queryRequestAuditLogsParams
|
||||
@@ -173,58 +201,86 @@ async function queryUniqueFilterAttributes(
|
||||
eq(requestAuditLog.orgId, orgId)
|
||||
);
|
||||
|
||||
// Get unique actors
|
||||
const uniqueActors = await db
|
||||
.selectDistinct({
|
||||
actor: requestAuditLog.actor
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions);
|
||||
const DISTINCT_LIMIT = 500;
|
||||
|
||||
// Get unique locations
|
||||
const uniqueLocations = await db
|
||||
.selectDistinct({
|
||||
locations: requestAuditLog.location
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions);
|
||||
// TODO: SOMEONE PLEASE OPTIMIZE THIS!!!!!
|
||||
|
||||
// Get unique actors
|
||||
const uniqueHosts = await db
|
||||
.selectDistinct({
|
||||
hosts: requestAuditLog.host
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions);
|
||||
// Run all queries in parallel
|
||||
const [
|
||||
uniqueActors,
|
||||
uniqueLocations,
|
||||
uniqueHosts,
|
||||
uniquePaths,
|
||||
uniqueResources
|
||||
] = await Promise.all([
|
||||
primaryLogsDb
|
||||
.selectDistinct({ actor: requestAuditLog.actor })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryLogsDb
|
||||
.selectDistinct({ locations: requestAuditLog.location })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryLogsDb
|
||||
.selectDistinct({ hosts: requestAuditLog.host })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryLogsDb
|
||||
.selectDistinct({ paths: requestAuditLog.path })
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1),
|
||||
primaryLogsDb
|
||||
.selectDistinct({
|
||||
id: requestAuditLog.resourceId
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions)
|
||||
.limit(DISTINCT_LIMIT + 1)
|
||||
]);
|
||||
|
||||
// Get unique actors
|
||||
const uniquePaths = await db
|
||||
.selectDistinct({
|
||||
paths: requestAuditLog.path
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.where(baseConditions);
|
||||
// TODO: for stuff like the paths this is too restrictive so lets just show some of the paths and the user needs to
|
||||
// refine the time range to see what they need to see
|
||||
// if (
|
||||
// uniqueActors.length > DISTINCT_LIMIT ||
|
||||
// uniqueLocations.length > DISTINCT_LIMIT ||
|
||||
// uniqueHosts.length > DISTINCT_LIMIT ||
|
||||
// uniquePaths.length > DISTINCT_LIMIT ||
|
||||
// uniqueResources.length > DISTINCT_LIMIT
|
||||
// ) {
|
||||
// throw new Error("Too many distinct filter attributes to retrieve. Please refine your time range.");
|
||||
// }
|
||||
|
||||
// Get unique resources with names
|
||||
const uniqueResources = await db
|
||||
.selectDistinct({
|
||||
id: requestAuditLog.resourceId,
|
||||
name: resources.name
|
||||
})
|
||||
.from(requestAuditLog)
|
||||
.leftJoin(
|
||||
resources,
|
||||
eq(requestAuditLog.resourceId, resources.resourceId)
|
||||
)
|
||||
.where(baseConditions);
|
||||
// Fetch resource names from main database for the unique resource IDs
|
||||
const resourceIds = uniqueResources
|
||||
.map(row => row.id)
|
||||
.filter((id): id is number => id !== null);
|
||||
|
||||
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
|
||||
|
||||
if (resourceIds.length > 0) {
|
||||
const resourceDetails = await primaryDb
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name
|
||||
})
|
||||
.from(resources)
|
||||
.where(inArray(resources.resourceId, resourceIds));
|
||||
|
||||
resourcesWithNames = resourceDetails.map(r => ({
|
||||
id: r.resourceId,
|
||||
name: r.name
|
||||
}));
|
||||
}
|
||||
|
||||
return {
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter(
|
||||
(row): row is { id: number; name: string | null } => row.id !== null
|
||||
),
|
||||
resources: resourcesWithNames,
|
||||
locations: uniqueLocations
|
||||
.map((row) => row.locations)
|
||||
.filter((location): location is string => location !== null),
|
||||
@@ -267,7 +323,10 @@ export async function queryRequestAuditLogs(
|
||||
|
||||
const baseQuery = queryRequest(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
const logsRaw = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
// Enrich with resource details (handles cross-database scenario)
|
||||
const log = await enrichWithResourceDetails(logsRaw);
|
||||
|
||||
const totalCountResult = await countRequestQuery(data);
|
||||
const totalCount = totalCountResult[0].count;
|
||||
@@ -295,6 +354,14 @@ export async function queryRequestAuditLogs(
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
// if the message is "Too many distinct filter attributes to retrieve. Please refine your time range.", return a 400 and the message
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message ===
|
||||
"Too many distinct filter attributes to retrieve. Please refine your time range."
|
||||
) {
|
||||
return next(createHttpError(HttpCode.BAD_REQUEST, error.message));
|
||||
}
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
|
||||
@@ -91,3 +91,50 @@ export type QueryAccessAuditLogResponse = {
|
||||
locations: string[];
|
||||
};
|
||||
};
|
||||
|
||||
export type QueryConnectionAuditLogResponse = {
|
||||
log: {
|
||||
sessionId: string;
|
||||
siteResourceId: number | null;
|
||||
orgId: string | null;
|
||||
siteId: number | null;
|
||||
clientId: number | null;
|
||||
userId: string | null;
|
||||
sourceAddr: string;
|
||||
destAddr: string;
|
||||
protocol: string;
|
||||
startedAt: number;
|
||||
endedAt: number | null;
|
||||
bytesTx: number | null;
|
||||
bytesRx: number | null;
|
||||
resourceName: string | null;
|
||||
resourceNiceId: string | null;
|
||||
siteName: string | null;
|
||||
siteNiceId: string | null;
|
||||
clientName: string | null;
|
||||
clientNiceId: string | null;
|
||||
clientType: string | null;
|
||||
userEmail: string | null;
|
||||
}[];
|
||||
pagination: {
|
||||
total: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
};
|
||||
filterAttributes: {
|
||||
protocols: string[];
|
||||
destAddrs: string[];
|
||||
clients: {
|
||||
id: number;
|
||||
name: string;
|
||||
}[];
|
||||
resources: {
|
||||
id: number;
|
||||
name: string | null;
|
||||
}[];
|
||||
users: {
|
||||
id: string;
|
||||
email: string | null;
|
||||
}[];
|
||||
};
|
||||
};
|
||||
|
||||
242
server/routers/auth/deleteMyAccount.ts
Normal file
242
server/routers/auth/deleteMyAccount.ts
Normal file
@@ -0,0 +1,242 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, orgs, userOrgs, users } from "@server/db";
|
||||
import { eq, and, inArray } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { verifySession } from "@server/auth/sessions/verifySession";
|
||||
import {
|
||||
invalidateSession,
|
||||
createBlankSessionTokenCookie
|
||||
} from "@server/auth/sessions/app";
|
||||
import { verifyPassword } from "@server/auth/password";
|
||||
import { verifyTotpCode } from "@server/auth/totp";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#dynamic/lib/billing";
|
||||
import {
|
||||
deleteOrgById,
|
||||
sendTerminationMessages
|
||||
} from "@server/lib/deleteOrg";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
|
||||
const deleteMyAccountBody = z.strictObject({
|
||||
password: z.string().optional(),
|
||||
code: z.string().optional()
|
||||
});
|
||||
|
||||
export type DeleteMyAccountPreviewResponse = {
|
||||
preview: true;
|
||||
orgs: { orgId: string; name: string }[];
|
||||
twoFactorEnabled: boolean;
|
||||
};
|
||||
|
||||
export type DeleteMyAccountCodeRequestedResponse = {
|
||||
codeRequested: true;
|
||||
};
|
||||
|
||||
export type DeleteMyAccountSuccessResponse = {
|
||||
success: true;
|
||||
};
|
||||
|
||||
export async function deleteMyAccount(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const { user, session } = await verifySession(req);
|
||||
if (!user || !session) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Not authenticated")
|
||||
);
|
||||
}
|
||||
|
||||
if (user.serverAdmin) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Server admins cannot delete their account this way"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (user.type !== UserType.Internal) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Account deletion with password is only supported for internal users"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const parsed = deleteMyAccountBody.safeParse(req.body ?? {});
|
||||
if (!parsed.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsed.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
const { password, code } = parsed.data;
|
||||
|
||||
const userId = user.userId;
|
||||
|
||||
const ownedOrgsRows = await db
|
||||
.select({
|
||||
orgId: userOrgs.orgId,
|
||||
isOwner: userOrgs.isOwner,
|
||||
isBillingOrg: orgs.isBillingOrg
|
||||
})
|
||||
.from(userOrgs)
|
||||
.innerJoin(orgs, eq(userOrgs.orgId, orgs.orgId))
|
||||
.where(
|
||||
and(eq(userOrgs.userId, userId), eq(userOrgs.isOwner, true))
|
||||
);
|
||||
|
||||
const orgIds = ownedOrgsRows.map((r) => r.orgId);
|
||||
|
||||
if (build === "saas" && orgIds.length > 0) {
|
||||
const primaryOrgId = ownedOrgsRows.find(
|
||||
(r) => r.isBillingOrg && r.isOwner
|
||||
)?.orgId;
|
||||
if (primaryOrgId) {
|
||||
const { tier, active } = await getOrgTierData(primaryOrgId);
|
||||
if (active && tier) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"You must cancel your subscription before deleting your account"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!password) {
|
||||
const orgsWithNames =
|
||||
orgIds.length > 0
|
||||
? await db
|
||||
.select({
|
||||
orgId: orgs.orgId,
|
||||
name: orgs.name
|
||||
})
|
||||
.from(orgs)
|
||||
.where(inArray(orgs.orgId, orgIds))
|
||||
: [];
|
||||
return response<DeleteMyAccountPreviewResponse>(res, {
|
||||
data: {
|
||||
preview: true,
|
||||
orgs: orgsWithNames.map((o) => ({
|
||||
orgId: o.orgId,
|
||||
name: o.name ?? ""
|
||||
})),
|
||||
twoFactorEnabled: user.twoFactorEnabled ?? false
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Preview",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
const validPassword = await verifyPassword(
|
||||
password,
|
||||
user.passwordHash!
|
||||
);
|
||||
if (!validPassword) {
|
||||
return next(
|
||||
createHttpError(HttpCode.UNAUTHORIZED, "Invalid password")
|
||||
);
|
||||
}
|
||||
|
||||
if (user.twoFactorEnabled) {
|
||||
if (!code) {
|
||||
return response<DeleteMyAccountCodeRequestedResponse>(res, {
|
||||
data: { codeRequested: true },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Two-factor code required",
|
||||
status: HttpCode.ACCEPTED
|
||||
});
|
||||
}
|
||||
const validOTP = await verifyTotpCode(
|
||||
code,
|
||||
user.twoFactorSecret!,
|
||||
user.userId
|
||||
);
|
||||
if (!validOTP) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"The two-factor code you entered is incorrect"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const allDeletedNewtIds: string[] = [];
|
||||
const allOlmsToTerminate: string[] = [];
|
||||
|
||||
for (const row of ownedOrgsRows) {
|
||||
try {
|
||||
const result = await deleteOrgById(row.orgId);
|
||||
allDeletedNewtIds.push(...result.deletedNewtIds);
|
||||
allOlmsToTerminate.push(...result.olmsToTerminate);
|
||||
} catch (err) {
|
||||
logger.error(
|
||||
`Failed to delete org ${row.orgId} during account deletion`,
|
||||
err
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to delete organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
sendTerminationMessages({
|
||||
deletedNewtIds: allDeletedNewtIds,
|
||||
olmsToTerminate: allOlmsToTerminate
|
||||
});
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx.delete(users).where(eq(users.userId, userId));
|
||||
await calculateUserClientsForOrgs(userId, trx);
|
||||
});
|
||||
|
||||
try {
|
||||
await invalidateSession(session.sessionId);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Failed to invalidate session after account deletion",
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
const isSecure = req.protocol === "https";
|
||||
res.setHeader("Set-Cookie", createBlankSessionTokenCookie(isSecure));
|
||||
|
||||
return response<DeleteMyAccountSuccessResponse>(res, {
|
||||
data: { success: true },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Account deleted successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"An error occurred"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -17,3 +17,5 @@ export * from "./securityKey";
|
||||
export * from "./startDeviceWebAuth";
|
||||
export * from "./verifyDeviceWebAuth";
|
||||
export * from "./pollDeviceWebAuth";
|
||||
export * from "./lookupUser";
|
||||
export * from "./deleteMyAccount";
|
||||
224
server/routers/auth/lookupUser.ts
Normal file
224
server/routers/auth/lookupUser.ts
Normal file
@@ -0,0 +1,224 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
users,
|
||||
userOrgs,
|
||||
orgs,
|
||||
idpOrg,
|
||||
idp,
|
||||
idpOidcConfig
|
||||
} from "@server/db";
|
||||
import { eq, or, sql, and, isNotNull, inArray } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
|
||||
const lookupBodySchema = z.strictObject({
|
||||
identifier: z.string().min(1).toLowerCase()
|
||||
});
|
||||
|
||||
export type LookupUserResponse = {
|
||||
found: boolean;
|
||||
identifier: string;
|
||||
accounts: Array<{
|
||||
userId: string;
|
||||
email: string | null;
|
||||
username: string;
|
||||
hasInternalAuth: boolean;
|
||||
orgs: Array<{
|
||||
orgId: string;
|
||||
orgName: string;
|
||||
idps: Array<{
|
||||
idpId: number;
|
||||
name: string;
|
||||
variant: string | null;
|
||||
}>;
|
||||
hasInternalAuth: boolean;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "post",
|
||||
// path: "/auth/lookup-user",
|
||||
// description: "Lookup user accounts by username or email and return available authentication methods.",
|
||||
// tags: [OpenAPITags.Auth],
|
||||
// request: {
|
||||
// body: lookupBodySchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function lookupUser(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = lookupBodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { identifier } = parsedBody.data;
|
||||
|
||||
// Query users matching identifier (case-insensitive)
|
||||
// Match by username OR email
|
||||
const matchingUsers = await db
|
||||
.select({
|
||||
userId: users.userId,
|
||||
email: users.email,
|
||||
username: users.username,
|
||||
type: users.type,
|
||||
passwordHash: users.passwordHash,
|
||||
idpId: users.idpId
|
||||
})
|
||||
.from(users)
|
||||
.where(
|
||||
or(
|
||||
sql`LOWER(${users.username}) = ${identifier}`,
|
||||
sql`LOWER(${users.email}) = ${identifier}`
|
||||
)
|
||||
);
|
||||
|
||||
if (!matchingUsers || matchingUsers.length === 0) {
|
||||
return response<LookupUserResponse>(res, {
|
||||
data: {
|
||||
found: false,
|
||||
identifier,
|
||||
accounts: []
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "No accounts found",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
// Get unique user IDs
|
||||
const userIds = [...new Set(matchingUsers.map((u) => u.userId))];
|
||||
|
||||
// Get all org memberships for these users
|
||||
const orgMemberships = await db
|
||||
.select({
|
||||
userId: userOrgs.userId,
|
||||
orgId: userOrgs.orgId,
|
||||
orgName: orgs.name
|
||||
})
|
||||
.from(userOrgs)
|
||||
.innerJoin(orgs, eq(orgs.orgId, userOrgs.orgId))
|
||||
.where(inArray(userOrgs.userId, userIds));
|
||||
|
||||
// Get unique org IDs
|
||||
const orgIds = [...new Set(orgMemberships.map((m) => m.orgId))];
|
||||
|
||||
// Get all IdPs for these orgs
|
||||
const orgIdps =
|
||||
orgIds.length > 0
|
||||
? await db
|
||||
.select({
|
||||
orgId: idpOrg.orgId,
|
||||
idpId: idp.idpId,
|
||||
idpName: idp.name,
|
||||
variant: idpOidcConfig.variant
|
||||
})
|
||||
.from(idpOrg)
|
||||
.innerJoin(idp, eq(idp.idpId, idpOrg.idpId))
|
||||
.innerJoin(
|
||||
idpOidcConfig,
|
||||
eq(idpOidcConfig.idpId, idp.idpId)
|
||||
)
|
||||
.where(inArray(idpOrg.orgId, orgIds))
|
||||
: [];
|
||||
|
||||
// Build response structure
|
||||
const accounts: LookupUserResponse["accounts"] = [];
|
||||
|
||||
for (const user of matchingUsers) {
|
||||
const hasInternalAuth =
|
||||
user.type === UserType.Internal && user.passwordHash !== null;
|
||||
|
||||
// Get orgs for this user
|
||||
const userOrgMemberships = orgMemberships.filter(
|
||||
(m) => m.userId === user.userId
|
||||
);
|
||||
|
||||
// Deduplicate orgs (user might have multiple memberships in same org)
|
||||
const uniqueOrgs = new Map<string, typeof userOrgMemberships[0]>();
|
||||
for (const membership of userOrgMemberships) {
|
||||
if (!uniqueOrgs.has(membership.orgId)) {
|
||||
uniqueOrgs.set(membership.orgId, membership);
|
||||
}
|
||||
}
|
||||
|
||||
const orgsData = Array.from(uniqueOrgs.values()).map((membership) => {
|
||||
// Get IdPs for this org where the user (with the exact identifier) is authenticated via that IdP
|
||||
// Only show IdPs where the user's idpId matches
|
||||
// Internal users don't have an idpId, so they won't see any IdPs
|
||||
const orgIdpsList = orgIdps
|
||||
.filter((idp) => {
|
||||
if (idp.orgId !== membership.orgId) {
|
||||
return false;
|
||||
}
|
||||
// Only show IdPs where the user (with exact identifier) is authenticated via that IdP
|
||||
// This means user.idpId must match idp.idpId
|
||||
if (user.idpId !== null && user.idpId === idp.idpId) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.map((idp) => ({
|
||||
idpId: idp.idpId,
|
||||
name: idp.idpName,
|
||||
variant: idp.variant
|
||||
}));
|
||||
|
||||
// Check if user has internal auth for this org
|
||||
// User has internal auth if they have an internal account type
|
||||
const orgHasInternalAuth = hasInternalAuth;
|
||||
|
||||
return {
|
||||
orgId: membership.orgId,
|
||||
orgName: membership.orgName,
|
||||
idps: orgIdpsList,
|
||||
hasInternalAuth: orgHasInternalAuth
|
||||
};
|
||||
});
|
||||
|
||||
accounts.push({
|
||||
userId: user.userId,
|
||||
email: user.email,
|
||||
username: user.username,
|
||||
hasInternalAuth,
|
||||
orgs: orgsData
|
||||
});
|
||||
}
|
||||
|
||||
return response<LookupUserResponse>(res, {
|
||||
data: {
|
||||
found: true,
|
||||
identifier,
|
||||
accounts
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "User lookup completed",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import { eq, and, gt } from "drizzle-orm";
|
||||
import { createSession, generateSessionToken } from "@server/auth/sessions/app";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { stripPortFromHost } from "@server/lib/ip";
|
||||
|
||||
const paramsSchema = z.object({
|
||||
code: z.string().min(1, "Code is required")
|
||||
@@ -27,30 +28,6 @@ export type PollDeviceWebAuthResponse = {
|
||||
token?: string;
|
||||
};
|
||||
|
||||
// Helper function to extract IP from request (same as in startDeviceWebAuth)
|
||||
function extractIpFromRequest(req: Request): string | undefined {
|
||||
const ip = req.ip || req.socket.remoteAddress;
|
||||
if (!ip) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Handle IPv6 format [::1] or IPv4 format
|
||||
if (ip.startsWith("[") && ip.includes("]")) {
|
||||
const ipv6Match = ip.match(/\[(.*?)\]/);
|
||||
if (ipv6Match) {
|
||||
return ipv6Match[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle IPv4 with port (split at last colon)
|
||||
const lastColonIndex = ip.lastIndexOf(":");
|
||||
if (lastColonIndex !== -1) {
|
||||
return ip.substring(0, lastColonIndex);
|
||||
}
|
||||
|
||||
return ip;
|
||||
}
|
||||
|
||||
export async function pollDeviceWebAuth(
|
||||
req: Request,
|
||||
res: Response,
|
||||
@@ -70,7 +47,7 @@ export async function pollDeviceWebAuth(
|
||||
try {
|
||||
const { code } = parsedParams.data;
|
||||
const now = Date.now();
|
||||
const requestIp = extractIpFromRequest(req);
|
||||
const requestIp = req.ip ? stripPortFromHost(req.ip) : undefined;
|
||||
|
||||
// Hash the code before querying
|
||||
const hashedCode = hashDeviceCode(code);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db, users } from "@server/db";
|
||||
import { bannedEmails, bannedIps, db, users } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { z } from "zod";
|
||||
import { email, z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
@@ -21,9 +21,7 @@ import { hashPassword } from "@server/auth/password";
|
||||
import { checkValidInvite } from "@server/auth/checkValidInvite";
|
||||
import { passwordSchema } from "@server/auth/passwordSchema";
|
||||
import { UserType } from "@server/types/UserTypes";
|
||||
import { createUserAccountOrg } from "@server/lib/createUserAccountOrg";
|
||||
import { build } from "@server/build";
|
||||
import resend, { AudienceIds, moveEmailToAudience } from "#dynamic/lib/resend";
|
||||
|
||||
export const signupBodySchema = z.object({
|
||||
email: z.email().toLowerCase(),
|
||||
@@ -31,7 +29,8 @@ export const signupBodySchema = z.object({
|
||||
inviteToken: z.string().optional(),
|
||||
inviteId: z.string().optional(),
|
||||
termsAcceptedTimestamp: z.string().nullable().optional(),
|
||||
marketingEmailConsent: z.boolean().optional()
|
||||
marketingEmailConsent: z.boolean().optional(),
|
||||
skipVerificationEmail: z.boolean().optional()
|
||||
});
|
||||
|
||||
export type SignUpBody = z.infer<typeof signupBodySchema>;
|
||||
@@ -62,9 +61,34 @@ export async function signup(
|
||||
inviteToken,
|
||||
inviteId,
|
||||
termsAcceptedTimestamp,
|
||||
marketingEmailConsent
|
||||
marketingEmailConsent,
|
||||
skipVerificationEmail
|
||||
} = parsedBody.data;
|
||||
|
||||
const [bannedEmail] = await db
|
||||
.select()
|
||||
.from(bannedEmails)
|
||||
.where(eq(bannedEmails.email, email))
|
||||
.limit(1);
|
||||
if (bannedEmail) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "Signup blocked. Do not attempt to continue to use this service.")
|
||||
);
|
||||
}
|
||||
|
||||
if (req.ip) {
|
||||
const [bannedIp] = await db
|
||||
.select()
|
||||
.from(bannedIps)
|
||||
.where(eq(bannedIps.ip, req.ip))
|
||||
.limit(1);
|
||||
if (bannedIp) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "Signup blocked. Do not attempt to continue to use this service.")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const passwordHash = await hashPassword(password);
|
||||
const userId = generateId(15);
|
||||
|
||||
@@ -188,6 +212,7 @@ export async function signup(
|
||||
dateCreated: moment().toISOString(),
|
||||
termsAcceptedTimestamp: termsAcceptedTimestamp || null,
|
||||
termsVersion: "1",
|
||||
marketingEmailConsent: marketingEmailConsent ?? false,
|
||||
lastPasswordChange: new Date().getTime()
|
||||
});
|
||||
|
||||
@@ -198,26 +223,6 @@ export async function signup(
|
||||
// orgId: null,
|
||||
// });
|
||||
|
||||
if (build == "saas") {
|
||||
const { success, error, org } = await createUserAccountOrg(
|
||||
userId,
|
||||
email
|
||||
);
|
||||
if (!success) {
|
||||
if (error) {
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, error)
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to create user account and organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const token = generateSessionToken();
|
||||
const sess = await createSession(token, userId);
|
||||
const isSecure = req.protocol === "https";
|
||||
@@ -231,11 +236,17 @@ export async function signup(
|
||||
logger.debug(
|
||||
`User ${email} opted in to marketing emails during signup.`
|
||||
);
|
||||
moveEmailToAudience(email, AudienceIds.SignUps);
|
||||
// TODO: update user in Sendy
|
||||
}
|
||||
|
||||
if (config.getRawConfig().flags?.require_email_verification) {
|
||||
sendEmailVerificationCode(email, userId);
|
||||
if (!skipVerificationEmail) {
|
||||
sendEmailVerificationCode(email, userId);
|
||||
} else {
|
||||
logger.debug(
|
||||
`User ${email} opted out of verification email during signup.`
|
||||
);
|
||||
}
|
||||
|
||||
return response<SignUpResponse>(res, {
|
||||
data: {
|
||||
@@ -243,7 +254,9 @@ export async function signup(
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: `User created successfully. We sent an email to ${email} with a verification code.`,
|
||||
message: skipVerificationEmail
|
||||
? "User created successfully. Please verify your email."
|
||||
: `User created successfully. We sent an email to ${email} with a verification code.`,
|
||||
status: HttpCode.OK
|
||||
});
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import { TimeSpan } from "oslo";
|
||||
import { maxmindLookup } from "@server/db/maxmind";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { stripPortFromHost } from "@server/lib/ip";
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
@@ -39,30 +40,6 @@ function hashDeviceCode(code: string): string {
|
||||
return encodeHexLowerCase(sha256(new TextEncoder().encode(code)));
|
||||
}
|
||||
|
||||
// Helper function to extract IP from request
|
||||
function extractIpFromRequest(req: Request): string | undefined {
|
||||
const ip = req.ip;
|
||||
if (!ip) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Handle IPv6 format [::1] or IPv4 format
|
||||
if (ip.startsWith("[") && ip.includes("]")) {
|
||||
const ipv6Match = ip.match(/\[(.*?)\]/);
|
||||
if (ipv6Match) {
|
||||
return ipv6Match[1];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle IPv4 with port (split at last colon)
|
||||
const lastColonIndex = ip.lastIndexOf(":");
|
||||
if (lastColonIndex !== -1) {
|
||||
return ip.substring(0, lastColonIndex);
|
||||
}
|
||||
|
||||
return ip;
|
||||
}
|
||||
|
||||
// Helper function to get city from IP (if available)
|
||||
async function getCityFromIp(ip: string): Promise<string | undefined> {
|
||||
try {
|
||||
@@ -112,7 +89,7 @@ export async function startDeviceWebAuth(
|
||||
const hashedCode = hashDeviceCode(code);
|
||||
|
||||
// Extract IP from request
|
||||
const ip = extractIpFromRequest(req);
|
||||
const ip = req.ip ? stripPortFromHost(req.ip) : undefined;
|
||||
|
||||
// Get city (optional, may return undefined)
|
||||
const city = ip ? await getCityFromIp(ip) : undefined;
|
||||
|
||||
@@ -10,6 +10,7 @@ import { eq, and, gt } from "drizzle-orm";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { unauthorized } from "@server/auth/unauthorizedResponse";
|
||||
import { getIosDeviceName, getMacDeviceName } from "@server/db/names";
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
@@ -120,6 +121,11 @@ export async function verifyDeviceWebAuth(
|
||||
);
|
||||
}
|
||||
|
||||
const deviceName =
|
||||
getMacDeviceName(deviceCode.deviceName) ||
|
||||
getIosDeviceName(deviceCode.deviceName) ||
|
||||
deviceCode.deviceName;
|
||||
|
||||
// If verify is false, just return metadata without verifying
|
||||
if (!verify) {
|
||||
return response<VerifyDeviceWebAuthResponse>(res, {
|
||||
@@ -129,7 +135,7 @@ export async function verifyDeviceWebAuth(
|
||||
metadata: {
|
||||
ip: deviceCode.ip,
|
||||
city: deviceCode.city,
|
||||
deviceName: deviceCode.deviceName,
|
||||
deviceName: deviceName,
|
||||
applicationName: deviceCode.applicationName,
|
||||
createdAt: deviceCode.createdAt
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
import { SESSION_COOKIE_EXPIRES as RESOURCE_SESSION_COOKIE_EXPIRES } from "@server/auth/sessions/resource";
|
||||
import config from "@server/lib/config";
|
||||
import { response } from "@server/lib/response";
|
||||
import { stripPortFromHost } from "@server/lib/ip";
|
||||
|
||||
const exchangeSessionBodySchema = z.object({
|
||||
requestToken: z.string(),
|
||||
@@ -62,7 +63,7 @@ export async function exchangeSession(
|
||||
cleanHost = cleanHost.slice(0, -1 * matched.length);
|
||||
}
|
||||
|
||||
const clientIp = requestIp?.split(":")[0];
|
||||
const clientIp = requestIp ? stripPortFromHost(requestIp) : undefined;
|
||||
|
||||
const [resource] = await db
|
||||
.select()
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import { db, orgs, requestAuditLog } from "@server/db";
|
||||
import { logsDb, primaryLogsDb, db, orgs, requestAuditLog } from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq, lt } from "drizzle-orm";
|
||||
import cache from "@server/lib/cache";
|
||||
import { and, eq, lt, sql } from "drizzle-orm";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
|
||||
import { stripPortFromHost } from "@server/lib/ip";
|
||||
|
||||
import { sanitizeString } from "@server/lib/sanitize";
|
||||
|
||||
/**
|
||||
|
||||
@@ -10,7 +13,7 @@ Reasons:
|
||||
100 - Allowed by Rule
|
||||
101 - Allowed No Auth
|
||||
102 - Valid Access Token
|
||||
103 - Valid header auth
|
||||
103 - Valid Header Auth (HTTP Basic Auth)
|
||||
104 - Valid Pincode
|
||||
105 - Valid Password
|
||||
106 - Valid email
|
||||
@@ -48,27 +51,53 @@ const auditLogBuffer: Array<{
|
||||
|
||||
const BATCH_SIZE = 100; // Write to DB every 100 logs
|
||||
const BATCH_INTERVAL_MS = 5000; // Or every 5 seconds, whichever comes first
|
||||
const MAX_BUFFER_SIZE = 10000; // Prevent unbounded memory growth
|
||||
let flushTimer: NodeJS.Timeout | null = null;
|
||||
let isFlushInProgress = false;
|
||||
|
||||
/**
|
||||
* Flush buffered logs to database
|
||||
*/
|
||||
async function flushAuditLogs() {
|
||||
if (auditLogBuffer.length === 0) {
|
||||
if (auditLogBuffer.length === 0 || isFlushInProgress) {
|
||||
return;
|
||||
}
|
||||
|
||||
isFlushInProgress = true;
|
||||
|
||||
// Take all current logs and clear buffer
|
||||
const logsToWrite = auditLogBuffer.splice(0, auditLogBuffer.length);
|
||||
|
||||
try {
|
||||
// Batch insert all logs at once
|
||||
await db.insert(requestAuditLog).values(logsToWrite);
|
||||
// Use a transaction to ensure all inserts succeed or fail together
|
||||
// This prevents index corruption from partial writes
|
||||
await logsDb.transaction(async (tx) => {
|
||||
// Batch insert logs in groups of 25 to avoid overwhelming the database
|
||||
const BATCH_DB_SIZE = 25;
|
||||
for (let i = 0; i < logsToWrite.length; i += BATCH_DB_SIZE) {
|
||||
const batch = logsToWrite.slice(i, i + BATCH_DB_SIZE);
|
||||
await tx.insert(requestAuditLog).values(batch);
|
||||
}
|
||||
});
|
||||
logger.debug(`Flushed ${logsToWrite.length} audit logs to database`);
|
||||
} catch (error) {
|
||||
logger.error("Error flushing audit logs:", error);
|
||||
// On error, we lose these logs - consider a fallback strategy if needed
|
||||
// (e.g., write to file, or put back in buffer with retry limit)
|
||||
// On transaction error, put logs back at the front of the buffer to retry
|
||||
// but only if buffer isn't too large
|
||||
if (auditLogBuffer.length < MAX_BUFFER_SIZE - logsToWrite.length) {
|
||||
auditLogBuffer.unshift(...logsToWrite);
|
||||
logger.info(`Re-queued ${logsToWrite.length} audit logs for retry`);
|
||||
} else {
|
||||
logger.error(`Buffer full, dropped ${logsToWrite.length} audit logs`);
|
||||
}
|
||||
} finally {
|
||||
isFlushInProgress = false;
|
||||
// If buffer filled up while we were flushing, flush again
|
||||
if (auditLogBuffer.length >= BATCH_SIZE) {
|
||||
flushAuditLogs().catch((err) =>
|
||||
logger.error("Error in follow-up flush:", err)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,12 +123,16 @@ export async function shutdownAuditLogger() {
|
||||
clearTimeout(flushTimer);
|
||||
flushTimer = null;
|
||||
}
|
||||
// Force flush even if one is in progress by waiting and retrying
|
||||
while (isFlushInProgress) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
}
|
||||
await flushAuditLogs();
|
||||
}
|
||||
|
||||
async function getRetentionDays(orgId: string): Promise<number> {
|
||||
// check cache first
|
||||
const cached = cache.get<number>(`org_${orgId}_retentionDays`);
|
||||
const cached = await cache.get<number>(`org_${orgId}_retentionDays`);
|
||||
if (cached !== undefined) {
|
||||
return cached;
|
||||
}
|
||||
@@ -118,7 +151,7 @@ async function getRetentionDays(orgId: string): Promise<number> {
|
||||
}
|
||||
|
||||
// store the result in cache
|
||||
cache.set(
|
||||
await cache.set(
|
||||
`org_${orgId}_retentionDays`,
|
||||
org.settingsLogRetentionDaysRequest,
|
||||
300
|
||||
@@ -131,7 +164,7 @@ export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
|
||||
const cutoffTimestamp = calculateCutoffTimestamp(retentionDays);
|
||||
|
||||
try {
|
||||
await db
|
||||
await logsDb
|
||||
.delete(requestAuditLog)
|
||||
.where(
|
||||
and(
|
||||
@@ -208,49 +241,37 @@ export async function logRequestAudit(
|
||||
}
|
||||
|
||||
const clientIp = body.requestIp
|
||||
? (() => {
|
||||
if (
|
||||
body.requestIp.startsWith("[") &&
|
||||
body.requestIp.includes("]")
|
||||
) {
|
||||
// if brackets are found, extract the IPv6 address from between the brackets
|
||||
const ipv6Match = body.requestIp.match(/\[(.*?)\]/);
|
||||
if (ipv6Match) {
|
||||
return ipv6Match[1];
|
||||
}
|
||||
}
|
||||
|
||||
// ivp4
|
||||
// split at last colon
|
||||
const lastColonIndex = body.requestIp.lastIndexOf(":");
|
||||
if (lastColonIndex !== -1) {
|
||||
return body.requestIp.substring(0, lastColonIndex);
|
||||
}
|
||||
return body.requestIp;
|
||||
})()
|
||||
? stripPortFromHost(body.requestIp)
|
||||
: undefined;
|
||||
|
||||
// Prevent unbounded buffer growth - drop oldest entries if buffer is too large
|
||||
if (auditLogBuffer.length >= MAX_BUFFER_SIZE) {
|
||||
const dropped = auditLogBuffer.splice(0, BATCH_SIZE);
|
||||
logger.warn(
|
||||
`Audit log buffer exceeded max size (${MAX_BUFFER_SIZE}), dropped ${dropped.length} oldest entries`
|
||||
);
|
||||
}
|
||||
|
||||
// Add to buffer instead of writing directly to DB
|
||||
auditLogBuffer.push({
|
||||
timestamp,
|
||||
orgId: data.orgId,
|
||||
actorType,
|
||||
actor,
|
||||
actorId,
|
||||
metadata,
|
||||
orgId: sanitizeString(data.orgId),
|
||||
actorType: sanitizeString(actorType),
|
||||
actor: sanitizeString(actor),
|
||||
actorId: sanitizeString(actorId),
|
||||
metadata: sanitizeString(metadata),
|
||||
action: data.action,
|
||||
resourceId: data.resourceId,
|
||||
reason: data.reason,
|
||||
location: data.location,
|
||||
originalRequestURL: body.originalRequestURL,
|
||||
scheme: body.scheme,
|
||||
host: body.host,
|
||||
path: body.path,
|
||||
method: body.method,
|
||||
ip: clientIp,
|
||||
location: sanitizeString(data.location),
|
||||
originalRequestURL: sanitizeString(body.originalRequestURL) ?? "",
|
||||
scheme: sanitizeString(body.scheme) ?? "",
|
||||
host: sanitizeString(body.host) ?? "",
|
||||
path: sanitizeString(body.path) ?? "",
|
||||
method: sanitizeString(body.method) ?? "",
|
||||
ip: sanitizeString(clientIp),
|
||||
tls: body.tls
|
||||
});
|
||||
|
||||
// Flush immediately if buffer is full, otherwise schedule a flush
|
||||
if (auditLogBuffer.length >= BATCH_SIZE) {
|
||||
// Fire and forget - don't block the caller
|
||||
|
||||
@@ -4,23 +4,23 @@ import {
|
||||
getResourceByDomain,
|
||||
getResourceRules,
|
||||
getRoleResourceAccess,
|
||||
getUserOrgRole,
|
||||
getUserResourceAccess,
|
||||
getOrgLoginPage,
|
||||
getUserSessionWithUser
|
||||
} from "@server/db/queries/verifySessionQueries";
|
||||
import { getUserOrgRoles } from "@server/lib/userOrgRoles";
|
||||
import {
|
||||
LoginPage,
|
||||
Org,
|
||||
Resource,
|
||||
ResourceHeaderAuth,
|
||||
ResourceHeaderAuthExtendedCompatibility,
|
||||
ResourcePassword,
|
||||
ResourcePincode,
|
||||
ResourceRule,
|
||||
resourceSessions
|
||||
ResourceRule
|
||||
} from "@server/db";
|
||||
import config from "@server/lib/config";
|
||||
import { isIpInCidr } from "@server/lib/ip";
|
||||
import { isIpInCidr, stripPortFromHost } from "@server/lib/ip";
|
||||
import { response } from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -30,16 +30,17 @@ import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { getCountryCodeForIp } from "@server/lib/geoip";
|
||||
import { getAsnForIp } from "@server/lib/asn";
|
||||
import { getOrgTierData } from "#dynamic/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { verifyPassword } from "@server/auth/password";
|
||||
import {
|
||||
checkOrgAccessPolicy,
|
||||
enforceResourceSessionLength
|
||||
} from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { logRequestAudit } from "./logRequestAudit";
|
||||
import cache from "@server/lib/cache";
|
||||
import { REGIONS } from "@server/db/regions";
|
||||
import { localCache } from "#dynamic/lib/cache";
|
||||
import { APP_VERSION } from "@server/lib/consts";
|
||||
import { isSubscribed } from "#dynamic/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
|
||||
const verifyResourceSessionSchema = z.object({
|
||||
sessions: z.record(z.string(), z.string()).optional(),
|
||||
@@ -51,7 +52,8 @@ const verifyResourceSessionSchema = z.object({
|
||||
path: z.string(),
|
||||
method: z.string(),
|
||||
tls: z.boolean(),
|
||||
requestIp: z.string().optional()
|
||||
requestIp: z.string().optional(),
|
||||
badgerVersion: z.string().optional()
|
||||
});
|
||||
|
||||
export type VerifyResourceSessionSchema = z.infer<
|
||||
@@ -67,8 +69,10 @@ type BasicUserData = {
|
||||
|
||||
export type VerifyUserResponse = {
|
||||
valid: boolean;
|
||||
headerAuthChallenged?: boolean;
|
||||
redirectUrl?: string;
|
||||
userData?: BasicUserData;
|
||||
pangolinVersion?: string;
|
||||
};
|
||||
|
||||
export async function verifyResourceSession(
|
||||
@@ -97,31 +101,15 @@ export async function verifyResourceSession(
|
||||
requestIp,
|
||||
path,
|
||||
headers,
|
||||
query
|
||||
query,
|
||||
badgerVersion
|
||||
} = parsedBody.data;
|
||||
|
||||
// Extract HTTP Basic Auth credentials if present
|
||||
const clientHeaderAuth = extractBasicAuth(headers);
|
||||
|
||||
const clientIp = requestIp
|
||||
? (() => {
|
||||
logger.debug("Request IP:", { requestIp });
|
||||
if (requestIp.startsWith("[") && requestIp.includes("]")) {
|
||||
// if brackets are found, extract the IPv6 address from between the brackets
|
||||
const ipv6Match = requestIp.match(/\[(.*?)\]/);
|
||||
if (ipv6Match) {
|
||||
return ipv6Match[1];
|
||||
}
|
||||
}
|
||||
|
||||
// ivp4
|
||||
// split at last colon
|
||||
const lastColonIndex = requestIp.lastIndexOf(":");
|
||||
if (lastColonIndex !== -1) {
|
||||
return requestIp.substring(0, lastColonIndex);
|
||||
}
|
||||
return requestIp;
|
||||
})()
|
||||
? stripPortFromHost(requestIp, badgerVersion)
|
||||
: undefined;
|
||||
|
||||
logger.debug("Client IP:", { clientIp });
|
||||
@@ -130,9 +118,7 @@ export async function verifyResourceSession(
|
||||
? await getCountryCodeFromIp(clientIp)
|
||||
: undefined;
|
||||
|
||||
const ipAsn = clientIp
|
||||
? await getAsnFromIp(clientIp)
|
||||
: undefined;
|
||||
const ipAsn = clientIp ? await getAsnFromIp(clientIp) : undefined;
|
||||
|
||||
let cleanHost = host;
|
||||
// if the host ends with :port, strip it
|
||||
@@ -148,9 +134,10 @@ export async function verifyResourceSession(
|
||||
pincode: ResourcePincode | null;
|
||||
password: ResourcePassword | null;
|
||||
headerAuth: ResourceHeaderAuth | null;
|
||||
headerAuthExtendedCompatibility: ResourceHeaderAuthExtendedCompatibility | null;
|
||||
org: Org;
|
||||
}
|
||||
| undefined = cache.get(resourceCacheKey);
|
||||
| undefined = localCache.get(resourceCacheKey);
|
||||
|
||||
if (!resourceData) {
|
||||
const result = await getResourceByDomain(cleanHost);
|
||||
@@ -174,10 +161,16 @@ export async function verifyResourceSession(
|
||||
}
|
||||
|
||||
resourceData = result;
|
||||
cache.set(resourceCacheKey, resourceData, 5);
|
||||
localCache.set(resourceCacheKey, resourceData, 5);
|
||||
}
|
||||
|
||||
const { resource, pincode, password, headerAuth } = resourceData;
|
||||
const {
|
||||
resource,
|
||||
pincode,
|
||||
password,
|
||||
headerAuth,
|
||||
headerAuthExtendedCompatibility
|
||||
} = resourceData;
|
||||
|
||||
if (!resource) {
|
||||
logger.debug(`Resource not found ${cleanHost}`);
|
||||
@@ -412,7 +405,7 @@ export async function verifyResourceSession(
|
||||
// check for HTTP Basic Auth header
|
||||
const clientHeaderAuthKey = `headerAuth:${clientHeaderAuth}`;
|
||||
if (headerAuth && clientHeaderAuth) {
|
||||
if (cache.get(clientHeaderAuthKey)) {
|
||||
if (localCache.get(clientHeaderAuthKey)) {
|
||||
logger.debug(
|
||||
"Resource allowed because header auth is valid (cached)"
|
||||
);
|
||||
@@ -435,7 +428,7 @@ export async function verifyResourceSession(
|
||||
headerAuth.headerAuthHash
|
||||
)
|
||||
) {
|
||||
cache.set(clientHeaderAuthKey, clientHeaderAuth, 5);
|
||||
localCache.set(clientHeaderAuthKey, clientHeaderAuth, 5);
|
||||
logger.debug("Resource allowed because header auth is valid");
|
||||
|
||||
logRequestAudit(
|
||||
@@ -457,7 +450,8 @@ export async function verifyResourceSession(
|
||||
!sso &&
|
||||
!pincode &&
|
||||
!password &&
|
||||
!resource.emailWhitelistEnabled
|
||||
!resource.emailWhitelistEnabled &&
|
||||
!headerAuthExtendedCompatibility?.extendedCompatibilityIsActivated
|
||||
) {
|
||||
logRequestAudit(
|
||||
{
|
||||
@@ -478,7 +472,8 @@ export async function verifyResourceSession(
|
||||
!sso &&
|
||||
!pincode &&
|
||||
!password &&
|
||||
!resource.emailWhitelistEnabled
|
||||
!resource.emailWhitelistEnabled &&
|
||||
!headerAuthExtendedCompatibility?.extendedCompatibilityIsActivated
|
||||
) {
|
||||
logRequestAudit(
|
||||
{
|
||||
@@ -525,7 +520,7 @@ export async function verifyResourceSession(
|
||||
|
||||
if (resourceSessionToken) {
|
||||
const sessionCacheKey = `session:${resourceSessionToken}`;
|
||||
let resourceSession: any = cache.get(sessionCacheKey);
|
||||
let resourceSession: any = localCache.get(sessionCacheKey);
|
||||
|
||||
if (!resourceSession) {
|
||||
const result = await validateResourceSessionToken(
|
||||
@@ -534,7 +529,7 @@ export async function verifyResourceSession(
|
||||
);
|
||||
|
||||
resourceSession = result?.resourceSession;
|
||||
cache.set(sessionCacheKey, resourceSession, 5);
|
||||
localCache.set(sessionCacheKey, resourceSession, 5);
|
||||
}
|
||||
|
||||
if (resourceSession?.isRequestToken) {
|
||||
@@ -564,7 +559,7 @@ export async function verifyResourceSession(
|
||||
}
|
||||
|
||||
if (resourceSession) {
|
||||
// only run this check if not SSO sesion; SSO session length is checked later
|
||||
// only run this check if not SSO session; SSO session length is checked later
|
||||
const accessPolicy = await enforceResourceSessionLength(
|
||||
resourceSession,
|
||||
resourceData.org
|
||||
@@ -667,7 +662,7 @@ export async function verifyResourceSession(
|
||||
}:${resource.resourceId}`;
|
||||
|
||||
let allowedUserData: BasicUserData | null | undefined =
|
||||
cache.get(userAccessCacheKey);
|
||||
localCache.get(userAccessCacheKey);
|
||||
|
||||
if (allowedUserData === undefined) {
|
||||
allowedUserData = await isUserAllowedToAccessResource(
|
||||
@@ -676,7 +671,7 @@ export async function verifyResourceSession(
|
||||
resourceData.org
|
||||
);
|
||||
|
||||
cache.set(userAccessCacheKey, allowedUserData, 5);
|
||||
localCache.set(userAccessCacheKey, allowedUserData, 5);
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -708,6 +703,15 @@ export async function verifyResourceSession(
|
||||
}
|
||||
}
|
||||
|
||||
// If headerAuthExtendedCompatibility is activated but no clientHeaderAuth provided, force client to challenge
|
||||
if (
|
||||
headerAuthExtendedCompatibility &&
|
||||
headerAuthExtendedCompatibility.extendedCompatibilityIsActivated &&
|
||||
!clientHeaderAuth
|
||||
) {
|
||||
return headerAuthChallenged(res, redirectPath, resource.orgId);
|
||||
}
|
||||
|
||||
logger.debug("No more auth to check, resource not allowed");
|
||||
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
@@ -793,8 +797,12 @@ async function notAllowed(
|
||||
) {
|
||||
let loginPage: LoginPage | null = null;
|
||||
if (orgId) {
|
||||
const { tier } = await getOrgTierData(orgId); // returns null in oss
|
||||
if (tier === TierId.STANDARD) {
|
||||
const subscribed = await isSubscribed(
|
||||
// this is fine because the org login page is only a saas feature
|
||||
orgId,
|
||||
tierMatrix.loginPageDomain
|
||||
);
|
||||
if (subscribed) {
|
||||
loginPage = await getOrgLoginPage(orgId);
|
||||
}
|
||||
}
|
||||
@@ -816,7 +824,7 @@ async function notAllowed(
|
||||
}
|
||||
|
||||
const data = {
|
||||
data: { valid: false, redirectUrl },
|
||||
data: { valid: false, redirectUrl, pangolinVersion: APP_VERSION },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Access denied",
|
||||
@@ -830,8 +838,8 @@ function allowed(res: Response, userData?: BasicUserData) {
|
||||
const data = {
|
||||
data:
|
||||
userData !== undefined && userData !== null
|
||||
? { valid: true, ...userData }
|
||||
: { valid: true },
|
||||
? { valid: true, ...userData, pangolinVersion: APP_VERSION }
|
||||
: { valid: true, pangolinVersion: APP_VERSION },
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Access allowed",
|
||||
@@ -840,6 +848,54 @@ function allowed(res: Response, userData?: BasicUserData) {
|
||||
return response<VerifyUserResponse>(res, data);
|
||||
}
|
||||
|
||||
async function headerAuthChallenged(
|
||||
res: Response,
|
||||
redirectPath?: string,
|
||||
orgId?: string
|
||||
) {
|
||||
let loginPage: LoginPage | null = null;
|
||||
if (orgId) {
|
||||
const subscribed = await isSubscribed(
|
||||
orgId,
|
||||
tierMatrix.loginPageDomain
|
||||
); // this is fine because the org login page is only a saas feature
|
||||
if (subscribed) {
|
||||
loginPage = await getOrgLoginPage(orgId);
|
||||
}
|
||||
}
|
||||
|
||||
let redirectUrl: string | undefined = undefined;
|
||||
if (redirectPath) {
|
||||
let endpoint: string;
|
||||
|
||||
if (loginPage && loginPage.domainId && loginPage.fullDomain) {
|
||||
const secure = config
|
||||
.getRawConfig()
|
||||
.app.dashboard_url?.startsWith("https");
|
||||
const method = secure ? "https" : "http";
|
||||
endpoint = `${method}://${loginPage.fullDomain}`;
|
||||
} else {
|
||||
endpoint = config.getRawConfig().app.dashboard_url!;
|
||||
}
|
||||
redirectUrl = `${endpoint}${redirectPath}`;
|
||||
}
|
||||
|
||||
const data = {
|
||||
data: {
|
||||
headerAuthChallenged: true,
|
||||
valid: false,
|
||||
redirectUrl,
|
||||
pangolinVersion: APP_VERSION
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Access denied",
|
||||
status: HttpCode.OK
|
||||
};
|
||||
logger.debug(JSON.stringify(data));
|
||||
return response<VerifyUserResponse>(res, data);
|
||||
}
|
||||
|
||||
async function isUserAllowedToAccessResource(
|
||||
userSessionId: string,
|
||||
resource: Resource,
|
||||
@@ -864,9 +920,9 @@ async function isUserAllowedToAccessResource(
|
||||
return null;
|
||||
}
|
||||
|
||||
const userOrgRole = await getUserOrgRole(user.userId, resource.orgId);
|
||||
const userOrgRoles = await getUserOrgRoles(user.userId, resource.orgId);
|
||||
|
||||
if (!userOrgRole) {
|
||||
if (!userOrgRoles.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -884,15 +940,14 @@ async function isUserAllowedToAccessResource(
|
||||
|
||||
const roleResourceAccess = await getRoleResourceAccess(
|
||||
resource.resourceId,
|
||||
userOrgRole.roleId
|
||||
userOrgRoles.map((r) => r.roleId)
|
||||
);
|
||||
|
||||
if (roleResourceAccess) {
|
||||
if (roleResourceAccess && roleResourceAccess.length > 0) {
|
||||
return {
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
role: user.role
|
||||
role: userOrgRoles.map((r) => r.roleName).join(", ")
|
||||
};
|
||||
}
|
||||
|
||||
@@ -906,7 +961,7 @@ async function isUserAllowedToAccessResource(
|
||||
username: user.username,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
role: user.role
|
||||
role: userOrgRoles.map((r) => r.roleName).join(", ")
|
||||
};
|
||||
}
|
||||
|
||||
@@ -922,11 +977,11 @@ async function checkRules(
|
||||
): Promise<"ACCEPT" | "DROP" | "PASS" | undefined> {
|
||||
const ruleCacheKey = `rules:${resourceId}`;
|
||||
|
||||
let rules: ResourceRule[] | undefined = cache.get(ruleCacheKey);
|
||||
let rules: ResourceRule[] | undefined = localCache.get(ruleCacheKey);
|
||||
|
||||
if (!rules) {
|
||||
rules = await getResourceRules(resourceId);
|
||||
cache.set(ruleCacheKey, rules, 5);
|
||||
localCache.set(ruleCacheKey, rules, 5);
|
||||
}
|
||||
|
||||
if (rules.length === 0) {
|
||||
@@ -991,14 +1046,29 @@ export function isPathAllowed(pattern: string, path: string): boolean {
|
||||
logger.debug(`Normalized pattern parts: [${patternParts.join(", ")}]`);
|
||||
logger.debug(`Normalized path parts: [${pathParts.join(", ")}]`);
|
||||
|
||||
// Maximum recursion depth to prevent stack overflow and memory issues
|
||||
const MAX_RECURSION_DEPTH = 100;
|
||||
|
||||
// Recursive function to try different wildcard matches
|
||||
function matchSegments(patternIndex: number, pathIndex: number): boolean {
|
||||
const indent = " ".repeat(pathIndex); // Indent based on recursion depth
|
||||
function matchSegments(
|
||||
patternIndex: number,
|
||||
pathIndex: number,
|
||||
depth: number = 0
|
||||
): boolean {
|
||||
// Check recursion depth limit
|
||||
if (depth > MAX_RECURSION_DEPTH) {
|
||||
logger.warn(
|
||||
`Path matching exceeded maximum recursion depth (${MAX_RECURSION_DEPTH}) for pattern "${pattern}" and path "${path}"`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
const indent = " ".repeat(depth); // Indent based on recursion depth
|
||||
const currentPatternPart = patternParts[patternIndex];
|
||||
const currentPathPart = pathParts[pathIndex];
|
||||
|
||||
logger.debug(
|
||||
`${indent}Checking patternIndex=${patternIndex} (${currentPatternPart || "END"}) vs pathIndex=${pathIndex} (${currentPathPart || "END"})`
|
||||
`${indent}Checking patternIndex=${patternIndex} (${currentPatternPart || "END"}) vs pathIndex=${pathIndex} (${currentPathPart || "END"}) [depth=${depth}]`
|
||||
);
|
||||
|
||||
// If we've consumed all pattern parts, we should have consumed all path parts
|
||||
@@ -1031,7 +1101,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
|
||||
logger.debug(
|
||||
`${indent}Trying to skip wildcard (consume 0 segments)`
|
||||
);
|
||||
if (matchSegments(patternIndex + 1, pathIndex)) {
|
||||
if (matchSegments(patternIndex + 1, pathIndex, depth + 1)) {
|
||||
logger.debug(
|
||||
`${indent}Successfully matched by skipping wildcard`
|
||||
);
|
||||
@@ -1042,7 +1112,7 @@ export function isPathAllowed(pattern: string, path: string): boolean {
|
||||
logger.debug(
|
||||
`${indent}Trying to consume segment "${currentPathPart}" for wildcard`
|
||||
);
|
||||
if (matchSegments(patternIndex, pathIndex + 1)) {
|
||||
if (matchSegments(patternIndex, pathIndex + 1, depth + 1)) {
|
||||
logger.debug(
|
||||
`${indent}Successfully matched by consuming segment for wildcard`
|
||||
);
|
||||
@@ -1070,7 +1140,11 @@ export function isPathAllowed(pattern: string, path: string): boolean {
|
||||
logger.debug(
|
||||
`${indent}Segment with wildcard matches: "${currentPatternPart}" matches "${currentPathPart}"`
|
||||
);
|
||||
return matchSegments(patternIndex + 1, pathIndex + 1);
|
||||
return matchSegments(
|
||||
patternIndex + 1,
|
||||
pathIndex + 1,
|
||||
depth + 1
|
||||
);
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
@@ -1091,10 +1165,10 @@ export function isPathAllowed(pattern: string, path: string): boolean {
|
||||
`${indent}Segments match: "${currentPatternPart}" = "${currentPathPart}"`
|
||||
);
|
||||
// Move to next segments in both pattern and path
|
||||
return matchSegments(patternIndex + 1, pathIndex + 1);
|
||||
return matchSegments(patternIndex + 1, pathIndex + 1, depth + 1);
|
||||
}
|
||||
|
||||
const result = matchSegments(0, 0);
|
||||
const result = matchSegments(0, 0, 0);
|
||||
logger.debug(`Final result: ${result}`);
|
||||
return result;
|
||||
}
|
||||
@@ -1182,13 +1256,13 @@ export async function isIpInRegion(
|
||||
async function getAsnFromIp(ip: string): Promise<number | undefined> {
|
||||
const asnCacheKey = `asn:${ip}`;
|
||||
|
||||
let cachedAsn: number | undefined = cache.get(asnCacheKey);
|
||||
let cachedAsn: number | undefined = localCache.get(asnCacheKey);
|
||||
|
||||
if (!cachedAsn) {
|
||||
cachedAsn = await getAsnForIp(ip); // do it locally
|
||||
// Cache for longer since IP ASN doesn't change frequently
|
||||
if (cachedAsn) {
|
||||
cache.set(asnCacheKey, cachedAsn, 300); // 5 minutes
|
||||
localCache.set(asnCacheKey, cachedAsn, 300); // 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1198,12 +1272,15 @@ async function getAsnFromIp(ip: string): Promise<number | undefined> {
|
||||
async function getCountryCodeFromIp(ip: string): Promise<string | undefined> {
|
||||
const geoIpCacheKey = `geoip:${ip}`;
|
||||
|
||||
let cachedCountryCode: string | undefined = cache.get(geoIpCacheKey);
|
||||
let cachedCountryCode: string | undefined = localCache.get(geoIpCacheKey);
|
||||
|
||||
if (!cachedCountryCode) {
|
||||
cachedCountryCode = await getCountryCodeForIp(ip); // do it locally
|
||||
// Cache for longer since IP geolocation doesn't change frequently
|
||||
cache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
|
||||
// Only cache successful lookups to avoid filling cache with undefined values
|
||||
if (cachedCountryCode) {
|
||||
// Cache for longer since IP geolocation doesn't change frequently
|
||||
localCache.set(geoIpCacheKey, cachedCountryCode, 300); // 5 minutes
|
||||
}
|
||||
}
|
||||
|
||||
return cachedCountryCode;
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { Limit, Subscription, SubscriptionItem, Usage } from "@server/db";
|
||||
|
||||
export type GetOrgSubscriptionResponse = {
|
||||
subscription: Subscription | null;
|
||||
items: SubscriptionItem[];
|
||||
subscriptions: Array<{ subscription: Subscription; items: SubscriptionItem[] }>;
|
||||
/** When build === saas, true if org has exceeded plan limits (sites, users, etc.) */
|
||||
limitsExceeded?: boolean;
|
||||
};
|
||||
|
||||
export type GetOrgUsageResponse = {
|
||||
|
||||
@@ -20,7 +20,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/blueprint",
|
||||
description: "Apply a base64 encoded JSON blueprint to an organization",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
|
||||
tags: [OpenAPITags.Blueprint],
|
||||
request: {
|
||||
params: applyBlueprintParamsSchema,
|
||||
body: {
|
||||
|
||||
@@ -26,7 +26,8 @@ const applyBlueprintSchema = z
|
||||
message: `Invalid YAML: ${error instanceof Error ? error.message : "Unknown error"}`
|
||||
});
|
||||
}
|
||||
})
|
||||
}),
|
||||
source: z.enum(["API", "UI", "CLI"]).optional()
|
||||
})
|
||||
.strict();
|
||||
|
||||
@@ -42,7 +43,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/blueprint",
|
||||
description: "Create and apply a YAML blueprint to an organization",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
|
||||
tags: [OpenAPITags.Blueprint],
|
||||
request: {
|
||||
params: applyBlueprintParamsSchema,
|
||||
body: {
|
||||
@@ -84,7 +85,7 @@ export async function applyYAMLBlueprint(
|
||||
);
|
||||
}
|
||||
|
||||
const { blueprint: contents, name } = parsedBody.data;
|
||||
const { blueprint: contents, name, source = "UI" } = parsedBody.data;
|
||||
|
||||
logger.debug(`Received blueprint:`, contents);
|
||||
|
||||
@@ -107,7 +108,7 @@ export async function applyYAMLBlueprint(
|
||||
blueprint = await applyBlueprint({
|
||||
orgId,
|
||||
name,
|
||||
source: "UI",
|
||||
source,
|
||||
configData: parsedConfig
|
||||
});
|
||||
} catch (err) {
|
||||
|
||||
@@ -53,7 +53,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/blueprint/{blueprintId}",
|
||||
description: "Get a blueprint by its blueprint ID.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
|
||||
tags: [OpenAPITags.Blueprint],
|
||||
request: {
|
||||
params: getBlueprintSchema
|
||||
},
|
||||
|
||||
@@ -67,7 +67,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/blueprints",
|
||||
description: "List all blueprints for a organization.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.Blueprint],
|
||||
tags: [OpenAPITags.Blueprint],
|
||||
request: {
|
||||
params: z.object({
|
||||
orgId: z.string()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { Blueprint } from "@server/db";
|
||||
|
||||
export type BlueprintSource = "API" | "UI" | "NEWT";
|
||||
export type BlueprintSource = "API" | "UI" | "NEWT" | "CLI";
|
||||
|
||||
export type BlueprintData = Omit<Blueprint, "source"> & {
|
||||
source: BlueprintSource;
|
||||
|
||||
@@ -6,8 +6,8 @@ export type GetCertificateResponse = {
|
||||
status: string; // pending, requested, valid, expired, failed
|
||||
expiresAt: string | null;
|
||||
lastRenewalAttempt: Date | null;
|
||||
createdAt: string;
|
||||
updatedAt: string;
|
||||
createdAt: number;
|
||||
updatedAt: number;
|
||||
errorMessage?: string | null;
|
||||
renewalCount: number;
|
||||
};
|
||||
|
||||
95
server/routers/client/archiveClient.ts
Normal file
95
server/routers/client/archiveClient.ts
Normal file
@@ -0,0 +1,95 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const archiveClientSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/client/{clientId}/archive",
|
||||
description: "Archive a client by its client ID.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: archiveClientSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function archiveClient(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = archiveClientSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
// Check if client exists
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Client with ID ${clientId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (client.archived) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Client with ID ${clientId} is already archived`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// Archive the client
|
||||
await trx
|
||||
.update(clients)
|
||||
.set({ archived: true })
|
||||
.where(eq(clients.clientId, clientId));
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Client archived successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to archive client"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
102
server/routers/client/blockClient.ts
Normal file
102
server/routers/client/blockClient.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { sendTerminateClient } from "./terminate";
|
||||
import { OlmErrorCodes } from "../olm/error";
|
||||
|
||||
const blockClientSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/client/{clientId}/block",
|
||||
description: "Block a client by its client ID.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: blockClientSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function blockClient(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = blockClientSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
// Check if client exists
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Client with ID ${clientId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (client.blocked) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Client with ID ${clientId} is already blocked`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// Block the client
|
||||
await trx
|
||||
.update(clients)
|
||||
.set({ blocked: true, approvalState: "denied" })
|
||||
.where(eq(clients.clientId, clientId));
|
||||
|
||||
// Send terminate signal if there's an associated OLM and it's connected
|
||||
if (client.olmId && client.online) {
|
||||
await sendTerminateClient(client.clientId, OlmErrorCodes.TERMINATED_BLOCKED, client.olmId);
|
||||
}
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Client blocked successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to block client"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import { generateId } from "@server/auth/sessions/app";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { getUniqueClientName } from "@server/db/names";
|
||||
import { build } from "@server/build";
|
||||
|
||||
const createClientParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
@@ -47,7 +48,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/client",
|
||||
description: "Create a new client for an organization.",
|
||||
tags: [OpenAPITags.Client, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: createClientParamsSchema,
|
||||
body: {
|
||||
@@ -91,7 +92,7 @@ export async function createClient(
|
||||
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (req.user && !req.userOrgRoleId) {
|
||||
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
@@ -101,7 +102,7 @@ export async function createClient(
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid subnet format. Please provide a valid CIDR notation."
|
||||
"Invalid subnet format. Please provide a valid IP."
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -195,6 +196,12 @@ export async function createClient(
|
||||
const randomExitNode =
|
||||
exitNodesList[Math.floor(Math.random() * exitNodesList.length)];
|
||||
|
||||
if (!randomExitNode) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, `No exit nodes available. ${build == "saas" ? "Please contact support." : "You need to install gerbil to use the clients."}`)
|
||||
);
|
||||
}
|
||||
|
||||
const [adminRole] = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
@@ -227,7 +234,7 @@ export async function createClient(
|
||||
clientId: newClient.clientId
|
||||
});
|
||||
|
||||
if (req.user && req.userOrgRoleId != adminRole.roleId) {
|
||||
if (req.user && !req.userOrgRoleIds?.includes(adminRole.roleId)) {
|
||||
// make sure the user can access the client
|
||||
trx.insert(userClients).values({
|
||||
userId: req.user.userId,
|
||||
|
||||
@@ -49,7 +49,7 @@ registry.registerPath({
|
||||
path: "/org/{orgId}/user/{userId}/client",
|
||||
description:
|
||||
"Create a new client for a user and associate it with an existing olm.",
|
||||
tags: [OpenAPITags.Client, OpenAPITags.Org, OpenAPITags.User],
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
|
||||
@@ -11,6 +11,7 @@ import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { sendTerminateClient } from "./terminate";
|
||||
import { OlmErrorCodes } from "../olm/error";
|
||||
|
||||
const deleteClientSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
@@ -60,11 +61,12 @@ export async function deleteClient(
|
||||
);
|
||||
}
|
||||
|
||||
// Only allow deletion of machine clients (clients without userId)
|
||||
if (client.userId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Cannot delete a user client with this endpoint`
|
||||
`Cannot delete a user client. User clients must be archived instead.`
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -90,7 +92,7 @@ export async function deleteClient(
|
||||
await rebuildClientAssociationsFromClient(deletedClient, trx);
|
||||
|
||||
if (olm) {
|
||||
await sendTerminateClient(deletedClient.clientId, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion
|
||||
await sendTerminateClient(deletedClient.clientId, OlmErrorCodes.TERMINATED_DELETED, olm.olmId); // the olmId needs to be provided because it cant look it up after deletion
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db, olms } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { db, olms, users } from "@server/db";
|
||||
import { clients, currentFingerprint } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -10,6 +10,10 @@ import logger from "@server/logger";
|
||||
import stoi from "@server/lib/stoi";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { getUserDeviceName } from "@server/db/names";
|
||||
import { build } from "@server/build";
|
||||
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
|
||||
const getClientSchema = z.strictObject({
|
||||
clientId: z
|
||||
@@ -29,6 +33,11 @@ async function query(clientId?: number, niceId?: string, orgId?: string) {
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.leftJoin(olms, eq(clients.clientId, olms.clientId))
|
||||
.leftJoin(
|
||||
currentFingerprint,
|
||||
eq(olms.olmId, currentFingerprint.olmId)
|
||||
)
|
||||
.leftJoin(users, eq(clients.userId, users.userId))
|
||||
.limit(1);
|
||||
return res;
|
||||
} else if (niceId && orgId) {
|
||||
@@ -36,16 +45,197 @@ async function query(clientId?: number, niceId?: string, orgId?: string) {
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(and(eq(clients.niceId, niceId), eq(clients.orgId, orgId)))
|
||||
.leftJoin(olms, eq(olms.clientId, olms.clientId))
|
||||
.leftJoin(olms, eq(clients.clientId, olms.clientId))
|
||||
.leftJoin(
|
||||
currentFingerprint,
|
||||
eq(olms.olmId, currentFingerprint.olmId)
|
||||
)
|
||||
.leftJoin(users, eq(clients.userId, users.userId))
|
||||
.limit(1);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
type PostureData = {
|
||||
biometricsEnabled?: boolean | null | "-";
|
||||
diskEncrypted?: boolean | null | "-";
|
||||
firewallEnabled?: boolean | null | "-";
|
||||
autoUpdatesEnabled?: boolean | null | "-";
|
||||
tpmAvailable?: boolean | null | "-";
|
||||
windowsAntivirusEnabled?: boolean | null | "-";
|
||||
macosSipEnabled?: boolean | null | "-";
|
||||
macosGatekeeperEnabled?: boolean | null | "-";
|
||||
macosFirewallStealthMode?: boolean | null | "-";
|
||||
linuxAppArmorEnabled?: boolean | null | "-";
|
||||
linuxSELinuxEnabled?: boolean | null | "-";
|
||||
};
|
||||
|
||||
function maskPostureDataWithPlaceholder(posture: PostureData): PostureData {
|
||||
const masked: PostureData = {};
|
||||
for (const key of Object.keys(posture) as (keyof PostureData)[]) {
|
||||
if (posture[key] !== undefined && posture[key] !== null) {
|
||||
(masked as Record<keyof PostureData, "-">)[key] = "-";
|
||||
}
|
||||
}
|
||||
return masked;
|
||||
}
|
||||
|
||||
function getPlatformPostureData(
|
||||
platform: string | null | undefined,
|
||||
fingerprint: typeof currentFingerprint.$inferSelect | null
|
||||
): PostureData | null {
|
||||
if (!fingerprint) return null;
|
||||
|
||||
const normalizedPlatform = platform?.toLowerCase() || "unknown";
|
||||
const posture: PostureData = {};
|
||||
|
||||
// Windows: Hard drive encryption, Firewall, Auto updates, TPM availability, Windows Antivirus status
|
||||
if (normalizedPlatform === "windows") {
|
||||
if (
|
||||
fingerprint.diskEncrypted !== null &&
|
||||
fingerprint.diskEncrypted !== undefined
|
||||
) {
|
||||
posture.diskEncrypted = fingerprint.diskEncrypted;
|
||||
}
|
||||
if (
|
||||
fingerprint.firewallEnabled !== null &&
|
||||
fingerprint.firewallEnabled !== undefined
|
||||
) {
|
||||
posture.firewallEnabled = fingerprint.firewallEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.tpmAvailable !== null &&
|
||||
fingerprint.tpmAvailable !== undefined
|
||||
) {
|
||||
posture.tpmAvailable = fingerprint.tpmAvailable;
|
||||
}
|
||||
if (
|
||||
fingerprint.windowsAntivirusEnabled !== null &&
|
||||
fingerprint.windowsAntivirusEnabled !== undefined
|
||||
) {
|
||||
posture.windowsAntivirusEnabled =
|
||||
fingerprint.windowsAntivirusEnabled;
|
||||
}
|
||||
}
|
||||
// macOS: Hard drive encryption, Biometric configuration, Firewall, System Integrity Protection (SIP), Gatekeeper, Firewall stealth mode
|
||||
else if (normalizedPlatform === "macos") {
|
||||
if (
|
||||
fingerprint.diskEncrypted !== null &&
|
||||
fingerprint.diskEncrypted !== undefined
|
||||
) {
|
||||
posture.diskEncrypted = fingerprint.diskEncrypted;
|
||||
}
|
||||
if (
|
||||
fingerprint.biometricsEnabled !== null &&
|
||||
fingerprint.biometricsEnabled !== undefined
|
||||
) {
|
||||
posture.biometricsEnabled = fingerprint.biometricsEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.firewallEnabled !== null &&
|
||||
fingerprint.firewallEnabled !== undefined
|
||||
) {
|
||||
posture.firewallEnabled = fingerprint.firewallEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.macosSipEnabled !== null &&
|
||||
fingerprint.macosSipEnabled !== undefined
|
||||
) {
|
||||
posture.macosSipEnabled = fingerprint.macosSipEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.macosGatekeeperEnabled !== null &&
|
||||
fingerprint.macosGatekeeperEnabled !== undefined
|
||||
) {
|
||||
posture.macosGatekeeperEnabled = fingerprint.macosGatekeeperEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.macosFirewallStealthMode !== null &&
|
||||
fingerprint.macosFirewallStealthMode !== undefined
|
||||
) {
|
||||
posture.macosFirewallStealthMode =
|
||||
fingerprint.macosFirewallStealthMode;
|
||||
}
|
||||
if (
|
||||
fingerprint.autoUpdatesEnabled !== null &&
|
||||
fingerprint.autoUpdatesEnabled !== undefined
|
||||
) {
|
||||
posture.autoUpdatesEnabled = fingerprint.autoUpdatesEnabled;
|
||||
}
|
||||
}
|
||||
// Linux: Hard drive encryption, Firewall, AppArmor, SELinux, TPM availability
|
||||
else if (normalizedPlatform === "linux") {
|
||||
if (
|
||||
fingerprint.diskEncrypted !== null &&
|
||||
fingerprint.diskEncrypted !== undefined
|
||||
) {
|
||||
posture.diskEncrypted = fingerprint.diskEncrypted;
|
||||
}
|
||||
if (
|
||||
fingerprint.firewallEnabled !== null &&
|
||||
fingerprint.firewallEnabled !== undefined
|
||||
) {
|
||||
posture.firewallEnabled = fingerprint.firewallEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.linuxAppArmorEnabled !== null &&
|
||||
fingerprint.linuxAppArmorEnabled !== undefined
|
||||
) {
|
||||
posture.linuxAppArmorEnabled = fingerprint.linuxAppArmorEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.linuxSELinuxEnabled !== null &&
|
||||
fingerprint.linuxSELinuxEnabled !== undefined
|
||||
) {
|
||||
posture.linuxSELinuxEnabled = fingerprint.linuxSELinuxEnabled;
|
||||
}
|
||||
if (
|
||||
fingerprint.tpmAvailable !== null &&
|
||||
fingerprint.tpmAvailable !== undefined
|
||||
) {
|
||||
posture.tpmAvailable = fingerprint.tpmAvailable;
|
||||
}
|
||||
}
|
||||
// iOS: Biometric configuration
|
||||
else if (normalizedPlatform === "ios") {
|
||||
// none supported yet
|
||||
}
|
||||
// Android: Screen lock, Biometric configuration, Hard drive encryption
|
||||
else if (normalizedPlatform === "android") {
|
||||
if (
|
||||
fingerprint.diskEncrypted !== null &&
|
||||
fingerprint.diskEncrypted !== undefined
|
||||
) {
|
||||
posture.diskEncrypted = fingerprint.diskEncrypted;
|
||||
}
|
||||
}
|
||||
|
||||
// Only return if we have at least one posture field
|
||||
return Object.keys(posture).length > 0 ? posture : null;
|
||||
}
|
||||
|
||||
export type GetClientResponse = NonNullable<
|
||||
Awaited<ReturnType<typeof query>>
|
||||
>["clients"] & {
|
||||
olmId: string | null;
|
||||
agent: string | null;
|
||||
olmVersion: string | null;
|
||||
userEmail: string | null;
|
||||
userName: string | null;
|
||||
userUsername: string | null;
|
||||
fingerprint: {
|
||||
username: string | null;
|
||||
hostname: string | null;
|
||||
platform: string | null;
|
||||
osVersion: string | null;
|
||||
kernelVersion: string | null;
|
||||
arch: string | null;
|
||||
deviceModel: string | null;
|
||||
serialNumber: string | null;
|
||||
firstSeen: number | null;
|
||||
lastSeen: number | null;
|
||||
} | null;
|
||||
posture: PostureData | null;
|
||||
};
|
||||
|
||||
registry.registerPath({
|
||||
@@ -53,7 +243,7 @@ registry.registerPath({
|
||||
path: "/org/{orgId}/client/{niceId}",
|
||||
description:
|
||||
"Get a client by orgId and niceId. NiceId is a readable ID for the site and unique on a per org basis.",
|
||||
tags: [OpenAPITags.Org, OpenAPITags.Site],
|
||||
tags: [OpenAPITags.Site],
|
||||
request: {
|
||||
params: z.object({
|
||||
orgId: z.string(),
|
||||
@@ -105,9 +295,59 @@ export async function getClient(
|
||||
);
|
||||
}
|
||||
|
||||
const isUserDevice = client.user !== null && client.user !== undefined;
|
||||
|
||||
// Replace name with device name if OLM exists
|
||||
let clientName = client.clients.name;
|
||||
if (client.olms && isUserDevice) {
|
||||
const model = client.currentFingerprint?.deviceModel || null;
|
||||
clientName = getUserDeviceName(model, client.clients.name);
|
||||
}
|
||||
|
||||
// Build fingerprint data if available
|
||||
const fingerprintData = client.currentFingerprint
|
||||
? {
|
||||
username: client.currentFingerprint.username || null,
|
||||
hostname: client.currentFingerprint.hostname || null,
|
||||
platform: client.currentFingerprint.platform || null,
|
||||
osVersion: client.currentFingerprint.osVersion || null,
|
||||
kernelVersion:
|
||||
client.currentFingerprint.kernelVersion || null,
|
||||
arch: client.currentFingerprint.arch || null,
|
||||
deviceModel: client.currentFingerprint.deviceModel || null,
|
||||
serialNumber: client.currentFingerprint.serialNumber || null,
|
||||
firstSeen: client.currentFingerprint.firstSeen || null,
|
||||
lastSeen: client.currentFingerprint.lastSeen || null
|
||||
}
|
||||
: null;
|
||||
|
||||
// Build posture data if available (platform-specific)
|
||||
// Licensed: real values; not licensed: same keys but values set to "-"
|
||||
const rawPosture = getPlatformPostureData(
|
||||
client.currentFingerprint?.platform || null,
|
||||
client.currentFingerprint
|
||||
);
|
||||
const isOrgLicensed = await isLicensedOrSubscribed(
|
||||
client.clients.orgId,
|
||||
tierMatrix.devicePosture
|
||||
);
|
||||
const postureData: PostureData | null = rawPosture
|
||||
? isOrgLicensed
|
||||
? rawPosture
|
||||
: maskPostureDataWithPlaceholder(rawPosture)
|
||||
: null;
|
||||
|
||||
const data: GetClientResponse = {
|
||||
...client.clients,
|
||||
olmId: client.olms ? client.olms.olmId : null
|
||||
name: clientName,
|
||||
olmId: client.olms ? client.olms.olmId : null,
|
||||
agent: client.olms?.agent || null,
|
||||
olmVersion: client.olms?.version || null,
|
||||
userEmail: client.user?.email ?? null,
|
||||
userName: client.user?.name ?? null,
|
||||
userUsername: client.user?.username ?? null,
|
||||
fingerprint: fingerprintData,
|
||||
posture: postureData
|
||||
};
|
||||
|
||||
return response<GetClientResponse>(res, {
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
export * from "./pickClientDefaults";
|
||||
export * from "./createClient";
|
||||
export * from "./deleteClient";
|
||||
export * from "./archiveClient";
|
||||
export * from "./unarchiveClient";
|
||||
export * from "./blockClient";
|
||||
export * from "./unblockClient";
|
||||
export * from "./listClients";
|
||||
export * from "./listUserDevices";
|
||||
export * from "./updateClient";
|
||||
export * from "./getClient";
|
||||
export * from "./createUserClient";
|
||||
|
||||
@@ -1,32 +1,38 @@
|
||||
import { db, olms, users } from "@server/db";
|
||||
import {
|
||||
clients,
|
||||
clientSitesAssociationsCache,
|
||||
currentFingerprint,
|
||||
db,
|
||||
olms,
|
||||
orgs,
|
||||
roleClients,
|
||||
sites,
|
||||
userClients,
|
||||
clientSitesAssociationsCache
|
||||
users
|
||||
} from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import type { PaginatedResponse } from "@server/types/Pagination";
|
||||
import {
|
||||
and,
|
||||
count,
|
||||
asc,
|
||||
desc,
|
||||
eq,
|
||||
inArray,
|
||||
isNotNull,
|
||||
isNull,
|
||||
like,
|
||||
or,
|
||||
sql
|
||||
sql,
|
||||
type SQL
|
||||
} from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import NodeCache from "node-cache";
|
||||
import semver from "semver";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
|
||||
|
||||
@@ -56,15 +62,15 @@ async function getLatestOlmVersion(): Promise<string | null> {
|
||||
return null;
|
||||
}
|
||||
|
||||
const tags = await response.json();
|
||||
let tags = await response.json();
|
||||
if (!Array.isArray(tags) || tags.length === 0) {
|
||||
logger.warn("No tags found for Olm repository");
|
||||
return null;
|
||||
}
|
||||
|
||||
tags = tags.filter((version) => !version.name.includes("rc"));
|
||||
const latestVersion = tags[0].name;
|
||||
|
||||
olmVersionCache.set("latestOlmVersion", latestVersion);
|
||||
olmVersionCache.set("latestOlmVersion", latestVersion, 3600);
|
||||
|
||||
return latestVersion;
|
||||
} catch (error: any) {
|
||||
@@ -87,38 +93,86 @@ const listClientsParamsSchema = z.strictObject({
|
||||
});
|
||||
|
||||
const listClientsSchema = z.object({
|
||||
limit: z
|
||||
.string()
|
||||
pageSize: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.positive()
|
||||
.optional()
|
||||
.default("1000")
|
||||
.transform(Number)
|
||||
.pipe(z.int().positive()),
|
||||
offset: z
|
||||
.string()
|
||||
.catch(20)
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.default("0")
|
||||
.transform(Number)
|
||||
.pipe(z.int().nonnegative()),
|
||||
filter: z.enum(["user", "machine"]).optional()
|
||||
.catch(1)
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
sort_by: z
|
||||
.enum(["name", "megabytesIn", "megabytesOut"])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["name", "megabytesIn", "megabytesOut"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
}),
|
||||
online: z
|
||||
.enum(["true", "false"])
|
||||
.transform((v) => v === "true")
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "boolean",
|
||||
description: "Filter by online status"
|
||||
}),
|
||||
status: z.preprocess(
|
||||
(val: string | undefined) => {
|
||||
if (val) {
|
||||
return val.split(","); // the search query array is an array joined by commas
|
||||
}
|
||||
return undefined;
|
||||
},
|
||||
z
|
||||
.array(z.enum(["active", "blocked", "archived"]))
|
||||
.optional()
|
||||
.default(["active"])
|
||||
.catch(["active"])
|
||||
.openapi({
|
||||
type: "array",
|
||||
items: {
|
||||
type: "string",
|
||||
enum: ["active", "blocked", "archived"]
|
||||
},
|
||||
default: ["active"],
|
||||
description:
|
||||
"Filter by client status. Can be a comma-separated list of values. Defaults to 'active'."
|
||||
})
|
||||
)
|
||||
});
|
||||
|
||||
function queryClients(
|
||||
orgId: string,
|
||||
accessibleClientIds: number[],
|
||||
filter?: "user" | "machine"
|
||||
) {
|
||||
const conditions = [
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
];
|
||||
|
||||
// Add filter condition based on filter type
|
||||
if (filter === "user") {
|
||||
conditions.push(isNotNull(clients.userId));
|
||||
} else if (filter === "machine") {
|
||||
conditions.push(isNull(clients.userId));
|
||||
}
|
||||
|
||||
function queryClientsBase() {
|
||||
return db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
@@ -136,13 +190,17 @@ function queryClients(
|
||||
username: users.username,
|
||||
userEmail: users.email,
|
||||
niceId: clients.niceId,
|
||||
agent: olms.agent
|
||||
agent: olms.agent,
|
||||
approvalState: clients.approvalState,
|
||||
olmArchived: olms.archived,
|
||||
archived: clients.archived,
|
||||
blocked: clients.blocked
|
||||
})
|
||||
.from(clients)
|
||||
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
|
||||
.leftJoin(olms, eq(clients.clientId, olms.clientId))
|
||||
.leftJoin(users, eq(clients.userId, users.userId))
|
||||
.where(and(...conditions));
|
||||
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId));
|
||||
}
|
||||
|
||||
async function getSiteAssociations(clientIds: number[]) {
|
||||
@@ -160,29 +218,26 @@ async function getSiteAssociations(clientIds: number[]) {
|
||||
.where(inArray(clientSitesAssociationsCache.clientId, clientIds));
|
||||
}
|
||||
|
||||
type OlmWithUpdateAvailable = Awaited<ReturnType<typeof queryClients>>[0] & {
|
||||
type ClientWithSites = Awaited<ReturnType<typeof queryClientsBase>>[0] & {
|
||||
sites: Array<{
|
||||
siteId: number;
|
||||
siteName: string | null;
|
||||
siteNiceId: string | null;
|
||||
}>;
|
||||
olmUpdateAvailable?: boolean;
|
||||
};
|
||||
|
||||
export type ListClientsResponse = {
|
||||
clients: Array<
|
||||
Awaited<ReturnType<typeof queryClients>>[0] & {
|
||||
sites: Array<{
|
||||
siteId: number;
|
||||
siteName: string | null;
|
||||
siteNiceId: string | null;
|
||||
}>;
|
||||
olmUpdateAvailable?: boolean;
|
||||
}
|
||||
>;
|
||||
pagination: { total: number; limit: number; offset: number };
|
||||
};
|
||||
type OlmWithUpdateAvailable = ClientWithSites;
|
||||
|
||||
export type ListClientsResponse = PaginatedResponse<{
|
||||
clients: Array<ClientWithSites>;
|
||||
}>;
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/clients",
|
||||
description: "List all clients for an organization.",
|
||||
tags: [OpenAPITags.Client, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
query: listClientsSchema,
|
||||
params: listClientsParamsSchema
|
||||
@@ -205,7 +260,8 @@ export async function listClients(
|
||||
)
|
||||
);
|
||||
}
|
||||
const { limit, offset, filter } = parsedQuery.data;
|
||||
const { page, pageSize, online, query, status, sort_by, order } =
|
||||
parsedQuery.data;
|
||||
|
||||
const parsedParams = listClientsParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
@@ -241,7 +297,7 @@ export async function listClients(
|
||||
.where(
|
||||
or(
|
||||
eq(userClients.userId, req.user!.userId),
|
||||
eq(roleClients.roleId, req.userOrgRoleId!)
|
||||
inArray(roleClients.roleId, req.userOrgRoleIds!)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
@@ -254,28 +310,73 @@ export async function listClients(
|
||||
const accessibleClientIds = accessibleClients.map(
|
||||
(client) => client.clientId
|
||||
);
|
||||
const baseQuery = queryClients(orgId, accessibleClientIds, filter);
|
||||
|
||||
// Get client count with filter
|
||||
const countConditions = [
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId)
|
||||
const conditions = [
|
||||
and(
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId),
|
||||
isNull(clients.userId)
|
||||
)
|
||||
];
|
||||
|
||||
if (filter === "user") {
|
||||
countConditions.push(isNotNull(clients.userId));
|
||||
} else if (filter === "machine") {
|
||||
countConditions.push(isNull(clients.userId));
|
||||
if (typeof online !== "undefined") {
|
||||
conditions.push(eq(clients.online, online));
|
||||
}
|
||||
|
||||
const countQuery = db
|
||||
.select({ count: count() })
|
||||
.from(clients)
|
||||
.where(and(...countConditions));
|
||||
if (status.length > 0) {
|
||||
const filterAggregates: (SQL<unknown> | undefined)[] = [];
|
||||
|
||||
const clientsList = await baseQuery.limit(limit).offset(offset);
|
||||
const totalCountResult = await countQuery;
|
||||
const totalCount = totalCountResult[0].count;
|
||||
if (status.includes("active")) {
|
||||
filterAggregates.push(
|
||||
and(eq(clients.archived, false), eq(clients.blocked, false))
|
||||
);
|
||||
}
|
||||
|
||||
if (status.includes("archived")) {
|
||||
filterAggregates.push(eq(clients.archived, true));
|
||||
}
|
||||
if (status.includes("blocked")) {
|
||||
filterAggregates.push(eq(clients.blocked, true));
|
||||
}
|
||||
|
||||
conditions.push(or(...filterAggregates));
|
||||
}
|
||||
|
||||
if (query) {
|
||||
conditions.push(
|
||||
or(
|
||||
like(
|
||||
sql`LOWER(${clients.name})`,
|
||||
"%" + query.toLowerCase() + "%"
|
||||
),
|
||||
like(
|
||||
sql`LOWER(${clients.niceId})`,
|
||||
"%" + query.toLowerCase() + "%"
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const baseQuery = queryClientsBase().where(and(...conditions));
|
||||
|
||||
const countQuery = db.$count(baseQuery.as("filtered_clients"));
|
||||
|
||||
const listMachinesQuery = baseQuery
|
||||
.limit(pageSize)
|
||||
.offset(pageSize * (page - 1))
|
||||
.orderBy(
|
||||
sort_by
|
||||
? order === "asc"
|
||||
? asc(clients[sort_by])
|
||||
: desc(clients[sort_by])
|
||||
: asc(clients.name)
|
||||
);
|
||||
|
||||
const [clientsList, totalCount] = await Promise.all([
|
||||
listMachinesQuery,
|
||||
countQuery
|
||||
]);
|
||||
|
||||
// Get associated sites for all clients
|
||||
const clientIds = clientsList.map((client) => client.clientId);
|
||||
@@ -304,11 +405,13 @@ export async function listClients(
|
||||
>
|
||||
);
|
||||
|
||||
// Merge clients with their site associations
|
||||
const clientsWithSites = clientsList.map((client) => ({
|
||||
...client,
|
||||
sites: sitesByClient[client.clientId] || []
|
||||
}));
|
||||
// Merge clients with their site associations and replace name with device name
|
||||
const clientsWithSites = clientsList.map((client) => {
|
||||
return {
|
||||
...client,
|
||||
sites: sitesByClient[client.clientId] || []
|
||||
};
|
||||
});
|
||||
|
||||
const latestOlVersionPromise = getLatestOlmVersion();
|
||||
|
||||
@@ -347,11 +450,11 @@ export async function listClients(
|
||||
|
||||
return response<ListClientsResponse>(res, {
|
||||
data: {
|
||||
clients: clientsWithSites,
|
||||
clients: olmsWithUpdates,
|
||||
pagination: {
|
||||
total: totalCount,
|
||||
limit,
|
||||
offset
|
||||
page,
|
||||
pageSize
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
|
||||
500
server/routers/client/listUserDevices.ts
Normal file
500
server/routers/client/listUserDevices.ts
Normal file
@@ -0,0 +1,500 @@
|
||||
import { build } from "@server/build";
|
||||
import {
|
||||
clients,
|
||||
currentFingerprint,
|
||||
db,
|
||||
olms,
|
||||
orgs,
|
||||
roleClients,
|
||||
userClients,
|
||||
users
|
||||
} from "@server/db";
|
||||
import { getUserDeviceName } from "@server/db/names";
|
||||
import response from "@server/lib/response";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import type { PaginatedResponse } from "@server/types/Pagination";
|
||||
import {
|
||||
and,
|
||||
asc,
|
||||
desc,
|
||||
eq,
|
||||
inArray,
|
||||
isNotNull,
|
||||
isNull,
|
||||
like,
|
||||
or,
|
||||
sql,
|
||||
type SQL
|
||||
} from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import createHttpError from "http-errors";
|
||||
import NodeCache from "node-cache";
|
||||
import semver from "semver";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
|
||||
const olmVersionCache = new NodeCache({ stdTTL: 3600 });
|
||||
|
||||
async function getLatestOlmVersion(): Promise<string | null> {
|
||||
try {
|
||||
const cachedVersion = olmVersionCache.get<string>("latestOlmVersion");
|
||||
if (cachedVersion) {
|
||||
return cachedVersion;
|
||||
}
|
||||
|
||||
const controller = new AbortController();
|
||||
const timeoutId = setTimeout(() => controller.abort(), 1500);
|
||||
|
||||
const response = await fetch(
|
||||
"https://api.github.com/repos/fosrl/olm/tags",
|
||||
{
|
||||
signal: controller.signal
|
||||
}
|
||||
);
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
if (!response.ok) {
|
||||
logger.warn(
|
||||
`Failed to fetch latest Olm version from GitHub: ${response.status} ${response.statusText}`
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
let tags = await response.json();
|
||||
if (!Array.isArray(tags) || tags.length === 0) {
|
||||
logger.warn("No tags found for Olm repository");
|
||||
return null;
|
||||
}
|
||||
tags = tags.filter((version) => !version.name.includes("rc"));
|
||||
const latestVersion = tags[0].name;
|
||||
|
||||
olmVersionCache.set("latestOlmVersion", latestVersion, 3600);
|
||||
|
||||
return latestVersion;
|
||||
} catch (error: any) {
|
||||
if (error.name === "AbortError") {
|
||||
logger.warn("Request to fetch latest Olm version timed out (1.5s)");
|
||||
} else if (error.cause?.code === "UND_ERR_CONNECT_TIMEOUT") {
|
||||
logger.warn("Connection timeout while fetching latest Olm version");
|
||||
} else {
|
||||
logger.warn(
|
||||
"Error fetching latest Olm version:",
|
||||
error.message || error
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
const listUserDevicesParamsSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
const listUserDevicesSchema = z.object({
|
||||
pageSize: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.positive()
|
||||
.optional()
|
||||
.catch(20)
|
||||
.default(20)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 20,
|
||||
description: "Number of items per page"
|
||||
}),
|
||||
page: z.coerce
|
||||
.number<string>() // for prettier formatting
|
||||
.int()
|
||||
.min(0)
|
||||
.optional()
|
||||
.catch(1)
|
||||
.default(1)
|
||||
.openapi({
|
||||
type: "integer",
|
||||
default: 1,
|
||||
description: "Page number to retrieve"
|
||||
}),
|
||||
query: z.string().optional(),
|
||||
sort_by: z
|
||||
.enum(["megabytesIn", "megabytesOut"])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["megabytesIn", "megabytesOut"],
|
||||
description: "Field to sort by"
|
||||
}),
|
||||
order: z
|
||||
.enum(["asc", "desc"])
|
||||
.optional()
|
||||
.default("asc")
|
||||
.catch("asc")
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: ["asc", "desc"],
|
||||
default: "asc",
|
||||
description: "Sort order"
|
||||
}),
|
||||
online: z
|
||||
.enum(["true", "false"])
|
||||
.transform((v) => v === "true")
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "boolean",
|
||||
description: "Filter by online status"
|
||||
}),
|
||||
agent: z
|
||||
.enum([
|
||||
"windows",
|
||||
"android",
|
||||
"cli",
|
||||
"olm",
|
||||
"macos",
|
||||
"ios",
|
||||
"ipados",
|
||||
"unknown"
|
||||
])
|
||||
.optional()
|
||||
.catch(undefined)
|
||||
.openapi({
|
||||
type: "string",
|
||||
enum: [
|
||||
"windows",
|
||||
"android",
|
||||
"cli",
|
||||
"olm",
|
||||
"macos",
|
||||
"ios",
|
||||
"ipados",
|
||||
"unknown"
|
||||
],
|
||||
description:
|
||||
"Filter by agent type. Use 'unknown' to filter clients with no agent detected."
|
||||
}),
|
||||
status: z.preprocess(
|
||||
(val: string | undefined) => {
|
||||
if (val) {
|
||||
return val.split(","); // the search query array is an array joined by commas
|
||||
}
|
||||
return undefined;
|
||||
},
|
||||
z
|
||||
.array(
|
||||
z.enum(["active", "pending", "denied", "blocked", "archived"])
|
||||
)
|
||||
.optional()
|
||||
.default(["active", "pending"])
|
||||
.catch(["active", "pending"])
|
||||
.openapi({
|
||||
type: "array",
|
||||
items: {
|
||||
type: "string",
|
||||
enum: ["active", "pending", "denied", "blocked", "archived"]
|
||||
},
|
||||
default: ["active", "pending"],
|
||||
description:
|
||||
"Filter by device status. Can include multiple values separated by commas. 'active' means not archived, not blocked, and if approval is enabled, approved. 'pending' and 'denied' are only applicable if approval is enabled."
|
||||
})
|
||||
)
|
||||
});
|
||||
|
||||
function queryUserDevicesBase() {
|
||||
return db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
orgId: clients.orgId,
|
||||
name: clients.name,
|
||||
pubKey: clients.pubKey,
|
||||
subnet: clients.subnet,
|
||||
megabytesIn: clients.megabytesIn,
|
||||
megabytesOut: clients.megabytesOut,
|
||||
orgName: orgs.name,
|
||||
type: clients.type,
|
||||
online: clients.online,
|
||||
olmVersion: olms.version,
|
||||
userId: clients.userId,
|
||||
username: users.username,
|
||||
userEmail: users.email,
|
||||
niceId: clients.niceId,
|
||||
agent: olms.agent,
|
||||
approvalState: clients.approvalState,
|
||||
olmArchived: olms.archived,
|
||||
archived: clients.archived,
|
||||
blocked: clients.blocked,
|
||||
deviceModel: currentFingerprint.deviceModel,
|
||||
fingerprintPlatform: currentFingerprint.platform,
|
||||
fingerprintOsVersion: currentFingerprint.osVersion,
|
||||
fingerprintKernelVersion: currentFingerprint.kernelVersion,
|
||||
fingerprintArch: currentFingerprint.arch,
|
||||
fingerprintSerialNumber: currentFingerprint.serialNumber,
|
||||
fingerprintUsername: currentFingerprint.username,
|
||||
fingerprintHostname: currentFingerprint.hostname
|
||||
})
|
||||
.from(clients)
|
||||
.leftJoin(orgs, eq(clients.orgId, orgs.orgId))
|
||||
.leftJoin(olms, eq(clients.clientId, olms.clientId))
|
||||
.leftJoin(users, eq(clients.userId, users.userId))
|
||||
.leftJoin(currentFingerprint, eq(olms.olmId, currentFingerprint.olmId));
|
||||
}
|
||||
|
||||
type OlmWithUpdateAvailable = Awaited<
|
||||
ReturnType<typeof queryUserDevicesBase>
|
||||
>[0] & {
|
||||
olmUpdateAvailable?: boolean;
|
||||
};
|
||||
|
||||
export type ListUserDevicesResponse = PaginatedResponse<{
|
||||
devices: Array<OlmWithUpdateAvailable>;
|
||||
}>;
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/user-devices",
|
||||
description: "List all user devices for an organization.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
query: listUserDevicesSchema,
|
||||
params: listUserDevicesParamsSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function listUserDevices(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedQuery = listUserDevicesSchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
const { page, pageSize, query, sort_by, online, status, agent, order } =
|
||||
parsedQuery.data;
|
||||
|
||||
const parsedParams = listUserDevicesParamsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error)
|
||||
)
|
||||
);
|
||||
}
|
||||
const { orgId } = parsedParams.data;
|
||||
|
||||
if (req.user && orgId && orgId !== req.userOrgId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"User does not have access to this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let accessibleClients;
|
||||
if (req.user) {
|
||||
accessibleClients = await db
|
||||
.select({
|
||||
clientId: sql<number>`COALESCE(${userClients.clientId}, ${roleClients.clientId})`
|
||||
})
|
||||
.from(userClients)
|
||||
.fullJoin(
|
||||
roleClients,
|
||||
eq(userClients.clientId, roleClients.clientId)
|
||||
)
|
||||
.where(
|
||||
or(
|
||||
eq(userClients.userId, req.user!.userId),
|
||||
inArray(roleClients.roleId, req.userOrgRoleIds!)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
accessibleClients = await db
|
||||
.select({ clientId: clients.clientId })
|
||||
.from(clients)
|
||||
.where(eq(clients.orgId, orgId));
|
||||
}
|
||||
|
||||
const accessibleClientIds = accessibleClients.map(
|
||||
(client) => client.clientId
|
||||
);
|
||||
// Get client count with filter
|
||||
const conditions = [
|
||||
and(
|
||||
inArray(clients.clientId, accessibleClientIds),
|
||||
eq(clients.orgId, orgId),
|
||||
isNotNull(clients.userId)
|
||||
)
|
||||
];
|
||||
|
||||
if (query) {
|
||||
conditions.push(
|
||||
or(
|
||||
like(
|
||||
sql`LOWER(${clients.name})`,
|
||||
"%" + query.toLowerCase() + "%"
|
||||
),
|
||||
like(
|
||||
sql`LOWER(${clients.niceId})`,
|
||||
"%" + query.toLowerCase() + "%"
|
||||
),
|
||||
like(
|
||||
sql`LOWER(${users.email})`,
|
||||
"%" + query.toLowerCase() + "%"
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof online !== "undefined") {
|
||||
conditions.push(eq(clients.online, online));
|
||||
}
|
||||
|
||||
const agentValueMap = {
|
||||
windows: "Pangolin Windows",
|
||||
android: "Pangolin Android",
|
||||
ios: "Pangolin iOS",
|
||||
ipados: "Pangolin iPadOS",
|
||||
macos: "Pangolin macOS",
|
||||
cli: "Pangolin CLI",
|
||||
olm: "Olm CLI"
|
||||
} satisfies Record<
|
||||
Exclude<typeof agent, undefined | "unknown">,
|
||||
string
|
||||
>;
|
||||
if (typeof agent !== "undefined") {
|
||||
if (agent === "unknown") {
|
||||
conditions.push(isNull(olms.agent));
|
||||
} else {
|
||||
conditions.push(eq(olms.agent, agentValueMap[agent]));
|
||||
}
|
||||
}
|
||||
|
||||
if (status.length > 0) {
|
||||
const filterAggregates: (SQL<unknown> | undefined)[] = [];
|
||||
|
||||
if (status.includes("active")) {
|
||||
filterAggregates.push(
|
||||
and(
|
||||
eq(clients.archived, false),
|
||||
eq(clients.blocked, false),
|
||||
build !== "oss"
|
||||
? or(
|
||||
eq(clients.approvalState, "approved"),
|
||||
isNull(clients.approvalState) // approval state of `NULL` means approved by default
|
||||
)
|
||||
: undefined // undefined are automatically ignored by `drizzle-orm`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (status.includes("archived")) {
|
||||
filterAggregates.push(eq(clients.archived, true));
|
||||
}
|
||||
if (status.includes("blocked")) {
|
||||
filterAggregates.push(eq(clients.blocked, true));
|
||||
}
|
||||
|
||||
if (build !== "oss") {
|
||||
if (status.includes("pending")) {
|
||||
filterAggregates.push(eq(clients.approvalState, "pending"));
|
||||
}
|
||||
if (status.includes("denied")) {
|
||||
filterAggregates.push(eq(clients.approvalState, "denied"));
|
||||
}
|
||||
}
|
||||
|
||||
conditions.push(or(...filterAggregates));
|
||||
}
|
||||
|
||||
const baseQuery = queryUserDevicesBase().where(and(...conditions));
|
||||
|
||||
const countQuery = db.$count(baseQuery.as("filtered_clients"));
|
||||
|
||||
const listDevicesQuery = baseQuery
|
||||
.limit(pageSize)
|
||||
.offset(pageSize * (page - 1))
|
||||
.orderBy(
|
||||
sort_by
|
||||
? order === "asc"
|
||||
? asc(clients[sort_by])
|
||||
: desc(clients[sort_by])
|
||||
: asc(clients.clientId)
|
||||
);
|
||||
|
||||
const [clientsList, totalCount] = await Promise.all([
|
||||
listDevicesQuery,
|
||||
countQuery
|
||||
]);
|
||||
|
||||
// Merge clients with their site associations and replace name with device name
|
||||
const olmsWithUpdates: OlmWithUpdateAvailable[] = clientsList.map(
|
||||
(client) => {
|
||||
const model = client.deviceModel || null;
|
||||
const newName = getUserDeviceName(model, client.name);
|
||||
const OlmWithUpdate: OlmWithUpdateAvailable = {
|
||||
...client,
|
||||
name: newName
|
||||
};
|
||||
// Initially set to false, will be updated if version check succeeds
|
||||
OlmWithUpdate.olmUpdateAvailable = false;
|
||||
return OlmWithUpdate;
|
||||
}
|
||||
);
|
||||
|
||||
// Try to get the latest version, but don't block if it fails
|
||||
try {
|
||||
const latestOlmVersion = await getLatestOlmVersion();
|
||||
|
||||
if (latestOlmVersion) {
|
||||
olmsWithUpdates.forEach((client) => {
|
||||
try {
|
||||
client.olmUpdateAvailable = semver.lt(
|
||||
client.olmVersion ? client.olmVersion : "",
|
||||
latestOlmVersion
|
||||
);
|
||||
} catch (error) {
|
||||
client.olmUpdateAvailable = false;
|
||||
}
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
// Log the error but don't let it block the response
|
||||
logger.warn(
|
||||
"Failed to check for OLM updates, continuing without update info:",
|
||||
error
|
||||
);
|
||||
}
|
||||
|
||||
return response<ListUserDevicesResponse>(res, {
|
||||
data: {
|
||||
devices: olmsWithUpdates,
|
||||
pagination: {
|
||||
total: totalCount,
|
||||
page,
|
||||
pageSize
|
||||
}
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Clients retrieved successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/pick-client-defaults",
|
||||
description: "Return pre-requisite data for creating a client.",
|
||||
tags: [OpenAPITags.Client, OpenAPITags.Site],
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: pickClientDefaultsSchema
|
||||
},
|
||||
|
||||
@@ -1,37 +1,128 @@
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { db, olms, Transaction } from "@server/db";
|
||||
import { Alias, SubnetProxyTarget } from "@server/lib/ip";
|
||||
import { db, newts, olms } from "@server/db";
|
||||
import {
|
||||
Alias,
|
||||
convertSubnetProxyTargetsV2ToV1,
|
||||
SubnetProxyTarget,
|
||||
SubnetProxyTargetV2
|
||||
} from "@server/lib/ip";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
import logger from "@server/logger";
|
||||
import { eq } from "drizzle-orm";
|
||||
import semver from "semver";
|
||||
|
||||
export async function addTargets(newtId: string, targets: SubnetProxyTarget[]) {
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/targets/add`,
|
||||
data: targets
|
||||
});
|
||||
const NEWT_V2_TARGETS_VERSION = ">=1.10.3";
|
||||
|
||||
export async function convertTargetsIfNessicary(
|
||||
newtId: string,
|
||||
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[]
|
||||
) {
|
||||
// get the newt
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.newtId, newtId));
|
||||
if (!newt) {
|
||||
throw new Error(`No newt found for id: ${newtId}`);
|
||||
}
|
||||
|
||||
// check the semver
|
||||
if (
|
||||
newt.version &&
|
||||
!semver.satisfies(newt.version, NEWT_V2_TARGETS_VERSION)
|
||||
) {
|
||||
logger.debug(
|
||||
`addTargets Newt version ${newt.version} does not support targets v2 falling back`
|
||||
);
|
||||
targets = convertSubnetProxyTargetsV2ToV1(
|
||||
targets as SubnetProxyTargetV2[]
|
||||
);
|
||||
}
|
||||
|
||||
return targets;
|
||||
}
|
||||
|
||||
export async function addTargets(
|
||||
newtId: string,
|
||||
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
|
||||
version?: string | null
|
||||
) {
|
||||
targets = await convertTargetsIfNessicary(newtId, targets);
|
||||
|
||||
await sendToClient(
|
||||
newtId,
|
||||
{
|
||||
type: `newt/wg/targets/add`,
|
||||
data: targets
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
|
||||
);
|
||||
}
|
||||
|
||||
export async function removeTargets(
|
||||
newtId: string,
|
||||
targets: SubnetProxyTarget[]
|
||||
targets: SubnetProxyTarget[] | SubnetProxyTargetV2[],
|
||||
version?: string | null
|
||||
) {
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/targets/remove`,
|
||||
data: targets
|
||||
});
|
||||
targets = await convertTargetsIfNessicary(newtId, targets);
|
||||
|
||||
await sendToClient(
|
||||
newtId,
|
||||
{
|
||||
type: `newt/wg/targets/remove`,
|
||||
data: targets
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
|
||||
);
|
||||
}
|
||||
|
||||
export async function updateTargets(
|
||||
newtId: string,
|
||||
targets: {
|
||||
oldTargets: SubnetProxyTarget[];
|
||||
newTargets: SubnetProxyTarget[];
|
||||
}
|
||||
oldTargets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
|
||||
newTargets: SubnetProxyTarget[] | SubnetProxyTargetV2[];
|
||||
},
|
||||
version?: string | null
|
||||
) {
|
||||
await sendToClient(newtId, {
|
||||
type: `newt/wg/targets/update`,
|
||||
data: targets
|
||||
}).catch((error) => {
|
||||
// get the newt
|
||||
const [newt] = await db
|
||||
.select()
|
||||
.from(newts)
|
||||
.where(eq(newts.newtId, newtId));
|
||||
if (!newt) {
|
||||
logger.error(`addTargetsL No newt found for id: ${newtId}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// check the semver
|
||||
if (
|
||||
newt.version &&
|
||||
!semver.satisfies(newt.version, NEWT_V2_TARGETS_VERSION)
|
||||
) {
|
||||
logger.debug(
|
||||
`addTargets Newt version ${newt.version} does not support targets v2 falling back`
|
||||
);
|
||||
targets = {
|
||||
oldTargets: convertSubnetProxyTargetsV2ToV1(
|
||||
targets.oldTargets as SubnetProxyTargetV2[]
|
||||
),
|
||||
newTargets: convertSubnetProxyTargetsV2ToV1(
|
||||
targets.newTargets as SubnetProxyTargetV2[]
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
await sendToClient(
|
||||
newtId,
|
||||
{
|
||||
type: `newt/wg/targets/update`,
|
||||
data: {
|
||||
oldTargets: targets.oldTargets,
|
||||
newTargets: targets.newTargets
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "newt") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -41,7 +132,8 @@ export async function addPeerData(
|
||||
siteId: number,
|
||||
remoteSubnets: string[],
|
||||
aliases: Alias[],
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
version?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -53,16 +145,21 @@ export async function addPeerData(
|
||||
return; // ignore this because an olm might not be associated with the client anymore
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
version = olm.version;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/wg/peer/data/add`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: `olm/wg/peer/data/add`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -72,7 +169,8 @@ export async function removePeerData(
|
||||
siteId: number,
|
||||
remoteSubnets: string[],
|
||||
aliases: Alias[],
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
version?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -84,16 +182,21 @@ export async function removePeerData(
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
version = olm.version;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/wg/peer/data/remove`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: `olm/wg/peer/data/remove`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
remoteSubnets: remoteSubnets,
|
||||
aliases: aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -113,7 +216,8 @@ export async function updatePeerData(
|
||||
newAliases: Alias[];
|
||||
}
|
||||
| undefined,
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
version?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -125,16 +229,21 @@ export async function updatePeerData(
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
version = olm.version;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/wg/peer/data/update`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
...remoteSubnets,
|
||||
...aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: `olm/wg/peer/data/update`,
|
||||
data: {
|
||||
siteId: siteId,
|
||||
...remoteSubnets,
|
||||
...aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { db, olms } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { OlmErrorCodes } from "../olm/error";
|
||||
|
||||
export async function sendTerminateClient(
|
||||
clientId: number,
|
||||
error: (typeof OlmErrorCodes)[keyof typeof OlmErrorCodes],
|
||||
olmId?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
@@ -20,6 +22,9 @@ export async function sendTerminateClient(
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: `olm/terminate`,
|
||||
data: {}
|
||||
data: {
|
||||
code: error.code,
|
||||
message: error.message
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
93
server/routers/client/unarchiveClient.ts
Normal file
93
server/routers/client/unarchiveClient.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const unarchiveClientSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/client/{clientId}/unarchive",
|
||||
description: "Unarchive a client by its client ID.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: unarchiveClientSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function unarchiveClient(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = unarchiveClientSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
// Check if client exists
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Client with ID ${clientId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!client.archived) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Client with ID ${clientId} is not archived`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Unarchive the client
|
||||
await db
|
||||
.update(clients)
|
||||
.set({ archived: false })
|
||||
.where(eq(clients.clientId, clientId));
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Client unarchived successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to unarchive client"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
93
server/routers/client/unblockClient.ts
Normal file
93
server/routers/client/unblockClient.ts
Normal file
@@ -0,0 +1,93 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import { clients } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const unblockClientSchema = z.strictObject({
|
||||
clientId: z.string().transform(Number).pipe(z.int().positive())
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "post",
|
||||
path: "/client/{clientId}/unblock",
|
||||
description: "Unblock a client by its client ID.",
|
||||
tags: [OpenAPITags.Client],
|
||||
request: {
|
||||
params: unblockClientSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
|
||||
export async function unblockClient(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = unblockClientSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { clientId } = parsedParams.data;
|
||||
|
||||
// Check if client exists
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
`Client with ID ${clientId} not found`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (!client.blocked) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
`Client with ID ${clientId} is not blocked`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Unblock the client
|
||||
await db
|
||||
.update(clients)
|
||||
.set({ blocked: false, approvalState: null })
|
||||
.where(eq(clients.clientId, clientId));
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Client unblocked successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to unblock client"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { eq, and, ne } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
@@ -93,7 +93,8 @@ export async function updateClient(
|
||||
.where(
|
||||
and(
|
||||
eq(clients.niceId, niceId),
|
||||
eq(clients.orgId, clients.orgId)
|
||||
eq(clients.orgId, clients.orgId),
|
||||
ne(clients.clientId, clientId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
@@ -131,7 +131,7 @@ export async function createOrgDomain(
|
||||
}
|
||||
const rejectDomains = await usageService.checkLimitSet(
|
||||
orgId,
|
||||
false,
|
||||
|
||||
FeatureId.DOMAINS,
|
||||
{
|
||||
...usage,
|
||||
@@ -148,7 +148,6 @@ export async function createOrgDomain(
|
||||
}
|
||||
}
|
||||
|
||||
let numOrgDomains: OrgDomains[] | undefined;
|
||||
let aRecords: CreateDomainResponse["aRecords"];
|
||||
let cnameRecords: CreateDomainResponse["cnameRecords"];
|
||||
let txtRecords: CreateDomainResponse["txtRecords"];
|
||||
@@ -347,20 +346,9 @@ export async function createOrgDomain(
|
||||
await trx.insert(dnsRecords).values(recordsToInsert);
|
||||
}
|
||||
|
||||
numOrgDomains = await trx
|
||||
.select()
|
||||
.from(orgDomains)
|
||||
.where(eq(orgDomains.orgId, orgId));
|
||||
await usageService.add(orgId, FeatureId.DOMAINS, 1, trx);
|
||||
});
|
||||
|
||||
if (numOrgDomains) {
|
||||
await usageService.updateDaily(
|
||||
orgId,
|
||||
FeatureId.DOMAINS,
|
||||
numOrgDomains.length
|
||||
);
|
||||
}
|
||||
|
||||
if (!returned) {
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -36,8 +36,6 @@ export async function deleteAccountDomain(
|
||||
}
|
||||
const { domainId, orgId } = parsed.data;
|
||||
|
||||
let numOrgDomains: OrgDomains[] | undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
const [existing] = await trx
|
||||
.select()
|
||||
@@ -79,20 +77,9 @@ export async function deleteAccountDomain(
|
||||
|
||||
await trx.delete(domains).where(eq(domains.domainId, domainId));
|
||||
|
||||
numOrgDomains = await trx
|
||||
.select()
|
||||
.from(orgDomains)
|
||||
.where(eq(orgDomains.orgId, orgId));
|
||||
await usageService.add(orgId, FeatureId.DOMAINS, -1, trx);
|
||||
});
|
||||
|
||||
if (numOrgDomains) {
|
||||
await usageService.updateDaily(
|
||||
orgId,
|
||||
FeatureId.DOMAINS,
|
||||
numOrgDomains.length
|
||||
);
|
||||
}
|
||||
|
||||
return response<DeleteAccountDomainResponse>(res, {
|
||||
data: { success: true },
|
||||
success: true,
|
||||
|
||||
@@ -40,7 +40,8 @@ async function queryDomains(orgId: string, limit: number, offset: number) {
|
||||
tries: domains.tries,
|
||||
configManaged: domains.configManaged,
|
||||
certResolver: domains.certResolver,
|
||||
preferWildcardCert: domains.preferWildcardCert
|
||||
preferWildcardCert: domains.preferWildcardCert,
|
||||
errorMessage: domains.errorMessage
|
||||
})
|
||||
.from(orgDomains)
|
||||
.where(eq(orgDomains.orgId, orgId))
|
||||
@@ -59,7 +60,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/domains",
|
||||
description: "List all domains for a organization.",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Domain],
|
||||
request: {
|
||||
params: z.object({
|
||||
orgId: z.string()
|
||||
|
||||
@@ -18,6 +18,7 @@ import * as apiKeys from "./apiKeys";
|
||||
import * as logs from "./auditLogs";
|
||||
import * as newt from "./newt";
|
||||
import * as olm from "./olm";
|
||||
import * as serverInfo from "./serverInfo";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import {
|
||||
verifyAccessTokenAccess,
|
||||
@@ -40,7 +41,8 @@ import {
|
||||
verifyUserHasAction,
|
||||
verifyUserIsOrgOwner,
|
||||
verifySiteResourceAccess,
|
||||
verifyOlmAccess
|
||||
verifyOlmAccess,
|
||||
verifyLimits
|
||||
} from "@server/middlewares";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
import rateLimit, { ipKeyGenerator } from "express-rate-limit";
|
||||
@@ -48,7 +50,7 @@ import createHttpError from "http-errors";
|
||||
import { build } from "@server/build";
|
||||
import { createStore } from "#dynamic/lib/rateLimitStore";
|
||||
import { logActionAudit } from "#dynamic/middlewares";
|
||||
import { log } from "console";
|
||||
import { checkRoundTripMessage } from "./ws";
|
||||
|
||||
// Root routes
|
||||
export const unauthenticated = Router();
|
||||
@@ -63,9 +65,8 @@ authenticated.use(verifySessionUserMiddleware);
|
||||
|
||||
authenticated.get("/pick-org-defaults", org.pickOrgDefaults);
|
||||
authenticated.get("/org/checkId", org.checkId);
|
||||
if (build === "oss" || build === "enterprise") {
|
||||
authenticated.put("/org", getUserOrgs, org.createOrg);
|
||||
}
|
||||
|
||||
authenticated.put("/org", getUserOrgs, org.createOrg);
|
||||
|
||||
authenticated.get("/orgs", verifyUserIsServerAdmin, org.listOrgs);
|
||||
authenticated.get("/user/:userId/orgs", verifyIsLoggedInUser, org.listUserOrgs);
|
||||
@@ -79,21 +80,20 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/org/:orgId",
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateOrg),
|
||||
logActionAudit(ActionsEnum.updateOrg),
|
||||
org.updateOrg
|
||||
);
|
||||
|
||||
if (build !== "saas") {
|
||||
authenticated.delete(
|
||||
"/org/:orgId",
|
||||
verifyOrgAccess,
|
||||
verifyUserIsOrgOwner,
|
||||
verifyUserHasAction(ActionsEnum.deleteOrg),
|
||||
logActionAudit(ActionsEnum.deleteOrg),
|
||||
org.deleteOrg
|
||||
);
|
||||
}
|
||||
authenticated.delete(
|
||||
"/org/:orgId",
|
||||
verifyOrgAccess,
|
||||
verifyUserIsOrgOwner,
|
||||
verifyUserHasAction(ActionsEnum.deleteOrg),
|
||||
logActionAudit(ActionsEnum.deleteOrg),
|
||||
org.deleteOrg
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/site",
|
||||
@@ -102,6 +102,8 @@ authenticated.put(
|
||||
logActionAudit(ActionsEnum.createSite),
|
||||
site.createSite
|
||||
);
|
||||
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/sites",
|
||||
verifyOrgAccess,
|
||||
@@ -143,6 +145,13 @@ authenticated.get(
|
||||
client.listClients
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/user-devices",
|
||||
verifyOrgAccess,
|
||||
verifyUserHasAction(ActionsEnum.listClients),
|
||||
client.listUserDevices
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/client/:clientId",
|
||||
verifyClientAccess,
|
||||
@@ -161,6 +170,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/client",
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createClient),
|
||||
logActionAudit(ActionsEnum.createClient),
|
||||
client.createClient
|
||||
@@ -175,9 +185,46 @@ authenticated.delete(
|
||||
client.deleteClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/archive",
|
||||
verifyClientAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.archiveClient),
|
||||
logActionAudit(ActionsEnum.archiveClient),
|
||||
client.archiveClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/unarchive",
|
||||
verifyClientAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.unarchiveClient),
|
||||
logActionAudit(ActionsEnum.unarchiveClient),
|
||||
client.unarchiveClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/block",
|
||||
verifyClientAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.blockClient),
|
||||
logActionAudit(ActionsEnum.blockClient),
|
||||
client.blockClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/unblock",
|
||||
verifyClientAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.unblockClient),
|
||||
logActionAudit(ActionsEnum.unblockClient),
|
||||
client.unblockClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId",
|
||||
verifyClientAccess, // this will check if the user has access to the client
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateClient), // this will check if the user has permission to update the client
|
||||
logActionAudit(ActionsEnum.updateClient),
|
||||
client.updateClient
|
||||
@@ -192,6 +239,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
"/site/:siteId",
|
||||
verifySiteAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateSite),
|
||||
logActionAudit(ActionsEnum.updateSite),
|
||||
site.updateSite
|
||||
@@ -239,9 +287,9 @@ authenticated.get(
|
||||
|
||||
// Site Resource endpoints
|
||||
authenticated.put(
|
||||
"/org/:orgId/site/:siteId/resource",
|
||||
"/org/:orgId/site-resource",
|
||||
verifyOrgAccess,
|
||||
verifySiteAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createSiteResource),
|
||||
logActionAudit(ActionsEnum.createSiteResource),
|
||||
siteResource.createSiteResource
|
||||
@@ -263,28 +311,23 @@ authenticated.get(
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/site/:siteId/resource/:siteResourceId",
|
||||
verifyOrgAccess,
|
||||
verifySiteAccess,
|
||||
"/site-resource/:siteResourceId",
|
||||
verifySiteResourceAccess,
|
||||
verifyUserHasAction(ActionsEnum.getSiteResource),
|
||||
siteResource.getSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/site/:siteId/resource/:siteResourceId",
|
||||
verifyOrgAccess,
|
||||
verifySiteAccess,
|
||||
"/site-resource/:siteResourceId",
|
||||
verifySiteResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateSiteResource),
|
||||
logActionAudit(ActionsEnum.updateSiteResource),
|
||||
siteResource.updateSiteResource
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/site/:siteId/resource/:siteResourceId",
|
||||
verifyOrgAccess,
|
||||
verifySiteAccess,
|
||||
"/site-resource/:siteResourceId",
|
||||
verifySiteResourceAccess,
|
||||
verifyUserHasAction(ActionsEnum.deleteSiteResource),
|
||||
logActionAudit(ActionsEnum.deleteSiteResource),
|
||||
@@ -316,6 +359,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles",
|
||||
verifySiteResourceAccess,
|
||||
verifyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.setSiteResourceRoles
|
||||
@@ -325,6 +369,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/users",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceUsers
|
||||
@@ -334,6 +379,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceClients
|
||||
@@ -343,6 +389,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/add",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.addClientToSiteResource
|
||||
@@ -352,6 +399,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/remove",
|
||||
verifySiteResourceAccess,
|
||||
verifySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.removeClientFromSiteResource
|
||||
@@ -360,6 +408,7 @@ authenticated.post(
|
||||
authenticated.put(
|
||||
"/org/:orgId/resource",
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createResource),
|
||||
logActionAudit(ActionsEnum.createResource),
|
||||
resource.createResource
|
||||
@@ -474,6 +523,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/resource/:resourceId",
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateResource),
|
||||
logActionAudit(ActionsEnum.updateResource),
|
||||
resource.updateResource
|
||||
@@ -489,6 +539,7 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/resource/:resourceId/target",
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createTarget),
|
||||
logActionAudit(ActionsEnum.createTarget),
|
||||
target.createTarget
|
||||
@@ -503,6 +554,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/resource/:resourceId/rule",
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createResourceRule),
|
||||
logActionAudit(ActionsEnum.createResourceRule),
|
||||
resource.createResourceRule
|
||||
@@ -516,6 +568,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/rule/:ruleId",
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateResourceRule),
|
||||
logActionAudit(ActionsEnum.updateResourceRule),
|
||||
resource.updateResourceRule
|
||||
@@ -537,6 +590,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/target/:targetId",
|
||||
verifyTargetAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateTarget),
|
||||
logActionAudit(ActionsEnum.updateTarget),
|
||||
target.updateTarget
|
||||
@@ -552,6 +606,7 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/org/:orgId/role",
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createRole),
|
||||
logActionAudit(ActionsEnum.createRole),
|
||||
role.createRole
|
||||
@@ -562,6 +617,15 @@ authenticated.get(
|
||||
verifyUserHasAction(ActionsEnum.listRoles),
|
||||
role.listRoles
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/role/:roleId",
|
||||
verifyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateRole),
|
||||
logActionAudit(ActionsEnum.updateRole),
|
||||
role.updateRole
|
||||
);
|
||||
// authenticated.get(
|
||||
// "/role/:roleId",
|
||||
// verifyRoleAccess,
|
||||
@@ -582,19 +646,22 @@ authenticated.delete(
|
||||
logActionAudit(ActionsEnum.deleteRole),
|
||||
role.deleteRole
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/role/:roleId/add/:userId",
|
||||
verifyRoleAccess,
|
||||
verifyUserAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.addUserRole),
|
||||
logActionAudit(ActionsEnum.addUserRole),
|
||||
user.addUserRole
|
||||
user.addUserRoleLegacy
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/roles",
|
||||
verifyResourceAccess,
|
||||
verifyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
resource.setResourceRoles
|
||||
@@ -604,6 +671,7 @@ authenticated.post(
|
||||
"/resource/:resourceId/users",
|
||||
verifyResourceAccess,
|
||||
verifySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
resource.setResourceUsers
|
||||
@@ -612,6 +680,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/password`,
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourcePassword),
|
||||
logActionAudit(ActionsEnum.setResourcePassword),
|
||||
resource.setResourcePassword
|
||||
@@ -620,6 +689,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/pincode`,
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourcePincode),
|
||||
logActionAudit(ActionsEnum.setResourcePincode),
|
||||
resource.setResourcePincode
|
||||
@@ -628,6 +698,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/header-auth`,
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceHeaderAuth),
|
||||
logActionAudit(ActionsEnum.setResourceHeaderAuth),
|
||||
resource.setResourceHeaderAuth
|
||||
@@ -636,6 +707,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/whitelist`,
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setResourceWhitelist),
|
||||
logActionAudit(ActionsEnum.setResourceWhitelist),
|
||||
resource.setResourceWhitelist
|
||||
@@ -651,6 +723,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/access-token`,
|
||||
verifyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.generateAccessToken),
|
||||
logActionAudit(ActionsEnum.generateAccessToken),
|
||||
accessToken.generateAccessToken
|
||||
@@ -680,6 +753,8 @@ authenticated.get(
|
||||
|
||||
authenticated.get(`/org/:orgId/overview`, verifyOrgAccess, org.getOrgOverview);
|
||||
|
||||
authenticated.get(`/server-info`, serverInfo.getServerInfo);
|
||||
|
||||
authenticated.post(
|
||||
`/supporter-key/validate`,
|
||||
supporterKey.validateSupporterKey
|
||||
@@ -739,6 +814,7 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/org/:orgId/user",
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createOrgUser),
|
||||
logActionAudit(ActionsEnum.createOrgUser),
|
||||
user.createOrgUser
|
||||
@@ -748,6 +824,7 @@ authenticated.post(
|
||||
"/org/:orgId/user/:userId",
|
||||
verifyOrgAccess,
|
||||
verifyUserAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.updateOrgUser),
|
||||
logActionAudit(ActionsEnum.updateOrgUser),
|
||||
user.updateOrgUser
|
||||
@@ -816,11 +893,19 @@ authenticated.put("/user/:userId/olm", verifyIsLoggedInUser, olm.createUserOlm);
|
||||
|
||||
authenticated.get("/user/:userId/olms", verifyIsLoggedInUser, olm.listUserOlms);
|
||||
|
||||
authenticated.delete(
|
||||
"/user/:userId/olm/:olmId",
|
||||
authenticated.post(
|
||||
"/user/:userId/olm/:olmId/archive",
|
||||
verifyIsLoggedInUser,
|
||||
verifyOlmAccess,
|
||||
olm.deleteUserOlm
|
||||
verifyLimits,
|
||||
olm.archiveUserOlm
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/user/:userId/olm/:olmId/unarchive",
|
||||
verifyIsLoggedInUser,
|
||||
verifyOlmAccess,
|
||||
olm.unarchiveUserOlm
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
@@ -830,6 +915,12 @@ authenticated.get(
|
||||
olm.getUserOlm
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/user/:userId/olm/recover",
|
||||
verifyIsLoggedInUser,
|
||||
olm.recoverOlmWithFingerprint
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/idp/oidc",
|
||||
verifyUserIsServerAdmin,
|
||||
@@ -921,6 +1012,7 @@ authenticated.post(
|
||||
`/org/:orgId/api-key/:apiKeyId/actions`,
|
||||
verifyOrgAccess,
|
||||
verifyApiKeyAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.setApiKeyActions),
|
||||
logActionAudit(ActionsEnum.setApiKeyActions),
|
||||
apiKeys.setApiKeyActions
|
||||
@@ -937,6 +1029,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
`/org/:orgId/api-key`,
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createApiKey),
|
||||
logActionAudit(ActionsEnum.createApiKey),
|
||||
apiKeys.createOrgApiKey
|
||||
@@ -962,6 +1055,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
`/org/:orgId/domain`,
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.createOrgDomain),
|
||||
logActionAudit(ActionsEnum.createOrgDomain),
|
||||
domain.createOrgDomain
|
||||
@@ -971,6 +1065,7 @@ authenticated.post(
|
||||
`/org/:orgId/domain/:domainId/restart`,
|
||||
verifyOrgAccess,
|
||||
verifyDomainAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.restartOrgDomain),
|
||||
logActionAudit(ActionsEnum.restartOrgDomain),
|
||||
domain.restartOrgDomain
|
||||
@@ -1017,6 +1112,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/blueprint",
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.applyBlueprint),
|
||||
blueprints.applyYAMLBlueprint
|
||||
);
|
||||
@@ -1028,6 +1124,8 @@ authenticated.get(
|
||||
blueprints.getBlueprint
|
||||
);
|
||||
|
||||
authenticated.get("/ws/round-trip-message/:messageId", checkRoundTripMessage);
|
||||
|
||||
// Auth routes
|
||||
export const authRouter = Router();
|
||||
unauthenticated.use("/auth", authRouter);
|
||||
@@ -1076,6 +1174,22 @@ authRouter.post(
|
||||
auth.login
|
||||
);
|
||||
authRouter.post("/logout", auth.logout);
|
||||
authRouter.post("/delete-my-account", auth.deleteMyAccount);
|
||||
authRouter.post(
|
||||
"/lookup-user",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000,
|
||||
max: 15,
|
||||
keyGenerator: (req) =>
|
||||
`lookupUser:${req.body.identifier || ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only lookup users ${15} times every ${15} minutes. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
auth.lookupUser
|
||||
);
|
||||
authRouter.post(
|
||||
"/newt/get-token",
|
||||
rateLimit({
|
||||
@@ -1091,6 +1205,22 @@ authRouter.post(
|
||||
}),
|
||||
newt.getNewtToken
|
||||
);
|
||||
|
||||
authRouter.post(
|
||||
"/newt/register",
|
||||
rateLimit({
|
||||
windowMs: 15 * 60 * 1000,
|
||||
max: 30,
|
||||
keyGenerator: (req) =>
|
||||
`newtRegister:${req.body.provisioningKey?.split(".")[0] || ipKeyGenerator(req.ip || "")}`,
|
||||
handler: (req, res, next) => {
|
||||
const message = `You can only register a newt ${30} times every ${15} minutes. Please try again later.`;
|
||||
return next(createHttpError(HttpCode.TOO_MANY_REQUESTS, message));
|
||||
},
|
||||
store: createStore()
|
||||
}),
|
||||
newt.registerNewt
|
||||
);
|
||||
authRouter.post(
|
||||
"/olm/get-token",
|
||||
rateLimit({
|
||||
|
||||
@@ -6,6 +6,8 @@ export type GeneratedLicenseKey = {
|
||||
createdAt: string;
|
||||
tier: string;
|
||||
type: string;
|
||||
users: number;
|
||||
sites: number;
|
||||
};
|
||||
|
||||
export type ListGeneratedLicenseKeysResponse = GeneratedLicenseKey[];
|
||||
@@ -19,6 +21,7 @@ export type NewLicenseKey = {
|
||||
tier: string;
|
||||
type: string;
|
||||
quantity: number;
|
||||
quantity_2: number;
|
||||
isValid: boolean;
|
||||
updatedAt: string;
|
||||
createdAt: string;
|
||||
|
||||
@@ -125,7 +125,7 @@ export async function generateRelayMappings(exitNode: ExitNode) {
|
||||
// Add site as a destination for this client
|
||||
const destination: PeerDestination = {
|
||||
destinationIP: site.subnet.split("/")[0],
|
||||
destinationPort: site.listenPort
|
||||
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
|
||||
};
|
||||
|
||||
// Check if this destination is already in the array to avoid duplicates
|
||||
@@ -165,7 +165,7 @@ export async function generateRelayMappings(exitNode: ExitNode) {
|
||||
|
||||
const destination: PeerDestination = {
|
||||
destinationIP: peer.subnet.split("/")[0],
|
||||
destinationPort: peer.listenPort
|
||||
destinationPort: peer.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
|
||||
};
|
||||
|
||||
// Check for duplicates
|
||||
|
||||
@@ -51,7 +51,10 @@ export async function getConfig(
|
||||
);
|
||||
}
|
||||
|
||||
const exitNode = await createExitNode(publicKey, reachableAt);
|
||||
// clean up the public key - keep only valid base64 characters (A-Z, a-z, 0-9, +, /, =)
|
||||
const cleanedPublicKey = publicKey.replace(/[^A-Za-z0-9+/=]/g, "");
|
||||
|
||||
const exitNode = await createExitNode(cleanedPublicKey, reachableAt);
|
||||
|
||||
if (!exitNode) {
|
||||
return next(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { eq, and, lt, inArray, sql } from "drizzle-orm";
|
||||
import { sites } from "@server/db";
|
||||
import { sql } from "drizzle-orm";
|
||||
import { db } from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -11,19 +10,34 @@ import { FeatureId } from "@server/lib/billing/features";
|
||||
import { checkExitNodeOrg } from "#dynamic/lib/exitNodes";
|
||||
import { build } from "@server/build";
|
||||
|
||||
// Track sites that are already offline to avoid unnecessary queries
|
||||
const offlineSites = new Set<string>();
|
||||
|
||||
// Retry configuration for deadlock handling
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 50;
|
||||
|
||||
interface PeerBandwidth {
|
||||
publicKey: string;
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
}
|
||||
|
||||
interface AccumulatorEntry {
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
/** Present when the update came through a remote exit node. */
|
||||
exitNodeId?: number;
|
||||
/** Whether to record egress usage for billing purposes. */
|
||||
calcUsage: boolean;
|
||||
}
|
||||
|
||||
// Retry configuration for deadlock handling
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 50;
|
||||
|
||||
// How often to flush accumulated bandwidth data to the database
|
||||
const FLUSH_INTERVAL_MS = 300_000; // 300 seconds
|
||||
|
||||
// Maximum number of sites to include in a single batch UPDATE statement
|
||||
const BATCH_CHUNK_SIZE = 250;
|
||||
|
||||
// In-memory accumulator: publicKey -> AccumulatorEntry
|
||||
let accumulator = new Map<string, AccumulatorEntry>();
|
||||
|
||||
/**
|
||||
* Check if an error is a deadlock error
|
||||
*/
|
||||
@@ -63,6 +77,266 @@ async function withDeadlockRetry<T>(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a raw SQL query that returns rows, in a way that works across both
|
||||
* the PostgreSQL driver (which exposes `execute`) and the SQLite driver (which
|
||||
* exposes `all`). Drizzle's typed query builder doesn't support bulk
|
||||
* UPDATE … FROM (VALUES …) natively, so we drop to raw SQL here.
|
||||
*/
|
||||
async function dbQueryRows<T extends Record<string, unknown>>(
|
||||
query: Parameters<(typeof sql)["join"]>[0][number]
|
||||
): Promise<T[]> {
|
||||
const anyDb = db as any;
|
||||
if (typeof anyDb.execute === "function") {
|
||||
// PostgreSQL (node-postgres via Drizzle) — returns { rows: [...] } or an array
|
||||
const result = await anyDb.execute(query);
|
||||
return (Array.isArray(result) ? result : (result.rows ?? [])) as T[];
|
||||
}
|
||||
// SQLite (better-sqlite3 via Drizzle) — returns an array directly
|
||||
return (await anyDb.all(query)) as T[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true when the active database driver is SQLite (better-sqlite3).
|
||||
* Used to select the appropriate bulk-update strategy.
|
||||
*/
|
||||
function isSQLite(): boolean {
|
||||
return typeof (db as any).execute !== "function";
|
||||
}
|
||||
|
||||
/**
|
||||
* Flush all accumulated site bandwidth data to the database.
|
||||
*
|
||||
* Swaps out the accumulator before writing so that any bandwidth messages
|
||||
* received during the flush are captured in the new accumulator rather than
|
||||
* being lost or causing contention. Sites are updated in chunks via a single
|
||||
* batch UPDATE per chunk. Failed chunks are discarded — exact per-flush
|
||||
* accuracy is not critical and re-queuing is not worth the added complexity.
|
||||
*
|
||||
* This function is exported so that the application's graceful-shutdown
|
||||
* cleanup handler can call it before the process exits.
|
||||
*/
|
||||
export async function flushSiteBandwidthToDb(): Promise<void> {
|
||||
if (accumulator.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Atomically swap out the accumulator so new data keeps flowing in
|
||||
// while we write the snapshot to the database.
|
||||
const snapshot = accumulator;
|
||||
accumulator = new Map<string, AccumulatorEntry>();
|
||||
|
||||
const currentTime = new Date().toISOString();
|
||||
|
||||
// Sort by publicKey for consistent lock ordering across concurrent
|
||||
// writers — deadlock-prevention strategy.
|
||||
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
|
||||
a.localeCompare(b)
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`Flushing accumulated bandwidth data for ${sortedEntries.length} site(s) to the database`
|
||||
);
|
||||
|
||||
// Build a lookup so post-processing can reach each entry by publicKey.
|
||||
const snapshotMap = new Map(sortedEntries);
|
||||
|
||||
// Aggregate billing usage by org across all chunks.
|
||||
const orgUsageMap = new Map<string, number>();
|
||||
|
||||
// Process in chunks so individual queries stay at a reasonable size.
|
||||
for (let i = 0; i < sortedEntries.length; i += BATCH_CHUNK_SIZE) {
|
||||
const chunk = sortedEntries.slice(i, i + BATCH_CHUNK_SIZE);
|
||||
const chunkEnd = i + chunk.length - 1;
|
||||
|
||||
let rows: { orgId: string; pubKey: string }[] = [];
|
||||
|
||||
try {
|
||||
rows = await withDeadlockRetry(async () => {
|
||||
if (isSQLite()) {
|
||||
// SQLite: one UPDATE per row — no need for batch efficiency here.
|
||||
const results: { orgId: string; pubKey: string }[] = [];
|
||||
for (const [publicKey, { bytesIn, bytesOut }] of chunk) {
|
||||
const result = await dbQueryRows<{
|
||||
orgId: string;
|
||||
pubKey: string;
|
||||
}>(sql`
|
||||
UPDATE sites
|
||||
SET
|
||||
"bytesOut" = COALESCE("bytesOut", 0) + ${bytesIn},
|
||||
"bytesIn" = COALESCE("bytesIn", 0) + ${bytesOut},
|
||||
"lastBandwidthUpdate" = ${currentTime}
|
||||
WHERE "pubKey" = ${publicKey}
|
||||
RETURNING "orgId", "pubKey"
|
||||
`);
|
||||
results.push(...result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// PostgreSQL: batch UPDATE … FROM (VALUES …) — single round-trip per chunk.
|
||||
const valuesList = chunk.map(
|
||||
([publicKey, { bytesIn, bytesOut }]) =>
|
||||
sql`(${publicKey}, ${bytesIn}, ${bytesOut})`
|
||||
);
|
||||
const valuesClause = sql.join(valuesList, sql`, `);
|
||||
return dbQueryRows<{ orgId: string; pubKey: string }>(sql`
|
||||
UPDATE sites
|
||||
SET
|
||||
"bytesOut" = COALESCE("bytesOut", 0) + v.bytes_in,
|
||||
"bytesIn" = COALESCE("bytesIn", 0) + v.bytes_out,
|
||||
"lastBandwidthUpdate" = ${currentTime}
|
||||
FROM (VALUES ${valuesClause}) AS v(pub_key, bytes_in, bytes_out)
|
||||
WHERE sites."pubKey" = v.pub_key
|
||||
RETURNING sites."orgId" AS "orgId", sites."pubKey" AS "pubKey"
|
||||
`);
|
||||
}, `flush bandwidth chunk [${i}–${chunkEnd}]`);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to flush bandwidth chunk [${i}–${chunkEnd}], discarding ${chunk.length} site(s):`,
|
||||
error
|
||||
);
|
||||
// Discard the chunk — exact per-flush accuracy is not critical.
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collect billing usage from the returned rows.
|
||||
for (const { orgId, pubKey } of rows) {
|
||||
const entry = snapshotMap.get(pubKey);
|
||||
if (!entry) continue;
|
||||
|
||||
const { bytesIn, bytesOut, exitNodeId, calcUsage } = entry;
|
||||
|
||||
if (exitNodeId) {
|
||||
const notAllowed = await checkExitNodeOrg(exitNodeId, orgId);
|
||||
if (notAllowed) {
|
||||
logger.warn(
|
||||
`Exit node ${exitNodeId} is not allowed for org ${orgId}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (calcUsage) {
|
||||
const current = orgUsageMap.get(orgId) ?? 0;
|
||||
orgUsageMap.set(orgId, current + bytesIn + bytesOut);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process billing usage updates after all chunks are written.
|
||||
if (orgUsageMap.size > 0) {
|
||||
const sortedOrgIds = [...orgUsageMap.keys()].sort();
|
||||
|
||||
for (const orgId of sortedOrgIds) {
|
||||
try {
|
||||
const totalBandwidth = orgUsageMap.get(orgId)!;
|
||||
const bandwidthUsage = await usageService.add(
|
||||
orgId,
|
||||
FeatureId.EGRESS_DATA_MB,
|
||||
totalBandwidth
|
||||
);
|
||||
if (bandwidthUsage) {
|
||||
// Fire-and-forget — don't block the flush on limit checking.
|
||||
usageService
|
||||
.checkLimitSet(
|
||||
orgId,
|
||||
FeatureId.EGRESS_DATA_MB,
|
||||
bandwidthUsage
|
||||
)
|
||||
.catch((error: any) => {
|
||||
logger.error(
|
||||
`Error checking bandwidth limits for org ${orgId}:`,
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Error processing usage for org ${orgId}:`,
|
||||
error
|
||||
);
|
||||
// Continue with other orgs.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Periodic flush timer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const flushTimer = setInterval(async () => {
|
||||
try {
|
||||
await flushSiteBandwidthToDb();
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Unexpected error during periodic site bandwidth flush:",
|
||||
error
|
||||
);
|
||||
}
|
||||
}, FLUSH_INTERVAL_MS);
|
||||
|
||||
// Allow the process to exit normally even while the timer is pending.
|
||||
// The graceful-shutdown path (see server/cleanup.ts) will call
|
||||
// flushSiteBandwidthToDb() explicitly before process.exit(), so no data
|
||||
// is lost.
|
||||
flushTimer.unref();
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public API
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Accumulate bandwidth data reported by a gerbil or remote exit node.
|
||||
*
|
||||
* Only peers that actually transferred data (bytesIn > 0) are added to the
|
||||
* accumulator; peers with no activity are silently ignored, which means the
|
||||
* flush will only write rows that have genuinely changed.
|
||||
*
|
||||
* The function is intentionally synchronous in its fast path so that the
|
||||
* HTTP handler can respond immediately without waiting for any I/O.
|
||||
*/
|
||||
export async function updateSiteBandwidth(
|
||||
bandwidthData: PeerBandwidth[],
|
||||
calcUsageAndLimits: boolean,
|
||||
exitNodeId?: number
|
||||
): Promise<void> {
|
||||
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
|
||||
// Skip peers that haven't transferred any data — writing zeros to the
|
||||
// database would be a no-op anyway.
|
||||
if (bytesIn <= 0 && bytesOut <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const existing = accumulator.get(publicKey);
|
||||
if (existing) {
|
||||
existing.bytesIn += bytesIn;
|
||||
existing.bytesOut += bytesOut;
|
||||
// Retain the most-recent exitNodeId for this peer.
|
||||
if (exitNodeId !== undefined) {
|
||||
existing.exitNodeId = exitNodeId;
|
||||
}
|
||||
// Once calcUsage has been requested for a peer, keep it set for
|
||||
// the lifetime of this flush window.
|
||||
if (calcUsageAndLimits) {
|
||||
existing.calcUsage = true;
|
||||
}
|
||||
} else {
|
||||
accumulator.set(publicKey, {
|
||||
bytesIn,
|
||||
bytesOut,
|
||||
exitNodeId,
|
||||
calcUsage: calcUsageAndLimits
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HTTP handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const receiveBandwidth = async (
|
||||
req: Request,
|
||||
res: Response,
|
||||
@@ -75,7 +349,9 @@ export const receiveBandwidth = async (
|
||||
throw new Error("Invalid bandwidth data");
|
||||
}
|
||||
|
||||
await updateSiteBandwidth(bandwidthData, build == "saas"); // we are checking the usage on saas only
|
||||
// Accumulate in memory; the periodic timer (and the shutdown hook)
|
||||
// will write to the database.
|
||||
await updateSiteBandwidth(bandwidthData, build == "saas");
|
||||
|
||||
return response(res, {
|
||||
data: {},
|
||||
@@ -94,239 +370,3 @@ export const receiveBandwidth = async (
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
export async function updateSiteBandwidth(
|
||||
bandwidthData: PeerBandwidth[],
|
||||
calcUsageAndLimits: boolean,
|
||||
exitNodeId?: number
|
||||
) {
|
||||
const currentTime = new Date();
|
||||
const oneMinuteAgo = new Date(currentTime.getTime() - 60000); // 1 minute ago
|
||||
|
||||
// Sort bandwidth data by publicKey to ensure consistent lock ordering across all instances
|
||||
// This is critical for preventing deadlocks when multiple instances update the same sites
|
||||
const sortedBandwidthData = [...bandwidthData].sort((a, b) =>
|
||||
a.publicKey.localeCompare(b.publicKey)
|
||||
);
|
||||
|
||||
// First, handle sites that are actively reporting bandwidth
|
||||
const activePeers = sortedBandwidthData.filter((peer) => peer.bytesIn > 0);
|
||||
|
||||
// Aggregate usage data by organization (collected outside transaction)
|
||||
const orgUsageMap = new Map<string, number>();
|
||||
const orgUptimeMap = new Map<string, number>();
|
||||
|
||||
if (activePeers.length > 0) {
|
||||
// Remove any active peers from offline tracking since they're sending data
|
||||
activePeers.forEach((peer) => offlineSites.delete(peer.publicKey));
|
||||
|
||||
// Update each active site individually with retry logic
|
||||
// This reduces transaction scope and allows retries per-site
|
||||
for (const peer of activePeers) {
|
||||
try {
|
||||
const updatedSite = await withDeadlockRetry(async () => {
|
||||
const [result] = await db
|
||||
.update(sites)
|
||||
.set({
|
||||
megabytesOut: sql`${sites.megabytesOut} + ${peer.bytesIn}`,
|
||||
megabytesIn: sql`${sites.megabytesIn} + ${peer.bytesOut}`,
|
||||
lastBandwidthUpdate: currentTime.toISOString(),
|
||||
online: true
|
||||
})
|
||||
.where(eq(sites.pubKey, peer.publicKey))
|
||||
.returning({
|
||||
online: sites.online,
|
||||
orgId: sites.orgId,
|
||||
siteId: sites.siteId,
|
||||
lastBandwidthUpdate: sites.lastBandwidthUpdate
|
||||
});
|
||||
return result;
|
||||
}, `update active site ${peer.publicKey}`);
|
||||
|
||||
if (updatedSite) {
|
||||
if (exitNodeId) {
|
||||
const notAllowed = await checkExitNodeOrg(
|
||||
exitNodeId,
|
||||
updatedSite.orgId
|
||||
);
|
||||
if (notAllowed) {
|
||||
logger.warn(
|
||||
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
|
||||
);
|
||||
// Skip this site but continue processing others
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Aggregate bandwidth usage for the org
|
||||
const totalBandwidth = peer.bytesIn + peer.bytesOut;
|
||||
const currentOrgUsage =
|
||||
orgUsageMap.get(updatedSite.orgId) || 0;
|
||||
orgUsageMap.set(
|
||||
updatedSite.orgId,
|
||||
currentOrgUsage + totalBandwidth
|
||||
);
|
||||
|
||||
// Add 10 seconds of uptime for each active site
|
||||
const currentOrgUptime =
|
||||
orgUptimeMap.get(updatedSite.orgId) || 0;
|
||||
orgUptimeMap.set(
|
||||
updatedSite.orgId,
|
||||
currentOrgUptime + 10 / 60
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to update bandwidth for site ${peer.publicKey}:`,
|
||||
error
|
||||
);
|
||||
// Continue with other sites
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process usage updates outside of site update transactions
|
||||
// This separates the concerns and reduces lock contention
|
||||
if (calcUsageAndLimits && (orgUsageMap.size > 0 || orgUptimeMap.size > 0)) {
|
||||
// Sort org IDs to ensure consistent lock ordering
|
||||
const allOrgIds = [
|
||||
...new Set([...orgUsageMap.keys(), ...orgUptimeMap.keys()])
|
||||
].sort();
|
||||
|
||||
for (const orgId of allOrgIds) {
|
||||
try {
|
||||
// Process bandwidth usage for this org
|
||||
const totalBandwidth = orgUsageMap.get(orgId);
|
||||
if (totalBandwidth) {
|
||||
const bandwidthUsage = await usageService.add(
|
||||
orgId,
|
||||
FeatureId.EGRESS_DATA_MB,
|
||||
totalBandwidth
|
||||
);
|
||||
if (bandwidthUsage) {
|
||||
// Fire and forget - don't block on limit checking
|
||||
usageService
|
||||
.checkLimitSet(
|
||||
orgId,
|
||||
true,
|
||||
FeatureId.EGRESS_DATA_MB,
|
||||
bandwidthUsage
|
||||
)
|
||||
.catch((error: any) => {
|
||||
logger.error(
|
||||
`Error checking bandwidth limits for org ${orgId}:`,
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Process uptime usage for this org
|
||||
const totalUptime = orgUptimeMap.get(orgId);
|
||||
if (totalUptime) {
|
||||
const uptimeUsage = await usageService.add(
|
||||
orgId,
|
||||
FeatureId.SITE_UPTIME,
|
||||
totalUptime
|
||||
);
|
||||
if (uptimeUsage) {
|
||||
// Fire and forget - don't block on limit checking
|
||||
usageService
|
||||
.checkLimitSet(
|
||||
orgId,
|
||||
true,
|
||||
FeatureId.SITE_UPTIME,
|
||||
uptimeUsage
|
||||
)
|
||||
.catch((error: any) => {
|
||||
logger.error(
|
||||
`Error checking uptime limits for org ${orgId}:`,
|
||||
error
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(`Error processing usage for org ${orgId}:`, error);
|
||||
// Continue with other orgs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle sites that reported zero bandwidth but need online status updated
|
||||
const zeroBandwidthPeers = sortedBandwidthData.filter(
|
||||
(peer) => peer.bytesIn === 0 && !offlineSites.has(peer.publicKey)
|
||||
);
|
||||
|
||||
if (zeroBandwidthPeers.length > 0) {
|
||||
// Fetch all zero bandwidth sites in one query
|
||||
const zeroBandwidthSites = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(
|
||||
inArray(
|
||||
sites.pubKey,
|
||||
zeroBandwidthPeers.map((p) => p.publicKey)
|
||||
)
|
||||
);
|
||||
|
||||
// Sort by siteId to ensure consistent lock ordering
|
||||
const sortedZeroBandwidthSites = zeroBandwidthSites.sort(
|
||||
(a, b) => a.siteId - b.siteId
|
||||
);
|
||||
|
||||
for (const site of sortedZeroBandwidthSites) {
|
||||
let newOnlineStatus = site.online;
|
||||
|
||||
// Check if site should go offline based on last bandwidth update WITH DATA
|
||||
if (site.lastBandwidthUpdate) {
|
||||
const lastUpdateWithData = new Date(site.lastBandwidthUpdate);
|
||||
if (lastUpdateWithData < oneMinuteAgo) {
|
||||
newOnlineStatus = false;
|
||||
}
|
||||
} else {
|
||||
// No previous data update recorded, set to offline
|
||||
newOnlineStatus = false;
|
||||
}
|
||||
|
||||
// Only update online status if it changed
|
||||
if (site.online !== newOnlineStatus) {
|
||||
try {
|
||||
const updatedSite = await withDeadlockRetry(async () => {
|
||||
const [result] = await db
|
||||
.update(sites)
|
||||
.set({
|
||||
online: newOnlineStatus
|
||||
})
|
||||
.where(eq(sites.siteId, site.siteId))
|
||||
.returning();
|
||||
return result;
|
||||
}, `update offline status for site ${site.siteId}`);
|
||||
|
||||
if (updatedSite && exitNodeId) {
|
||||
const notAllowed = await checkExitNodeOrg(
|
||||
exitNodeId,
|
||||
updatedSite.orgId
|
||||
);
|
||||
if (notAllowed) {
|
||||
logger.warn(
|
||||
`Exit node ${exitNodeId} is not allowed for org ${updatedSite.orgId}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If site went offline, add it to our tracking set
|
||||
if (!newOnlineStatus && site.pubKey) {
|
||||
offlineSites.add(site.pubKey);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to update offline status for site ${site.siteId}:`,
|
||||
error
|
||||
);
|
||||
// Continue with other sites
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ export async function updateHolePunch(
|
||||
destinations: destinations
|
||||
});
|
||||
} catch (error) {
|
||||
// logger.error(error); // FIX THIS
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
@@ -262,7 +262,7 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
if (site.subnet && site.listenPort) {
|
||||
destinations.push({
|
||||
destinationIP: site.subnet.split("/")[0],
|
||||
destinationPort: site.listenPort
|
||||
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -339,10 +339,10 @@ export async function updateAndGenerateEndpointDestinations(
|
||||
handleSiteEndpointChange(newt.siteId, updatedSite.endpoint!);
|
||||
}
|
||||
|
||||
if (!updatedSite || !updatedSite.subnet) {
|
||||
logger.warn(`Site not found: ${newt.siteId}`);
|
||||
throw new Error("Site not found");
|
||||
}
|
||||
// if (!updatedSite || !updatedSite.subnet) {
|
||||
// logger.warn(`Site not found: ${newt.siteId}`);
|
||||
// throw new Error("Site not found");
|
||||
// }
|
||||
|
||||
// Find all clients that connect to this site
|
||||
// const sitesClientPairs = await db
|
||||
|
||||
@@ -27,7 +27,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/idp/{idpId}/org/{orgId}",
|
||||
description: "Create an IDP policy for an existing IDP on an organization.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
@@ -70,6 +70,15 @@ export async function createIdpOrgPolicy(
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
const { roleMapping, orgMapping } = parsedBody.data;
|
||||
|
||||
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [existing] = await db
|
||||
.select()
|
||||
.from(idp)
|
||||
|
||||
@@ -24,7 +24,9 @@ const bodySchema = z.strictObject({
|
||||
emailPath: z.string().optional(),
|
||||
namePath: z.string().optional(),
|
||||
scopes: z.string().nonempty(),
|
||||
autoProvision: z.boolean().optional()
|
||||
autoProvision: z.boolean().optional(),
|
||||
tags: z.string().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc")
|
||||
});
|
||||
|
||||
export type CreateIdpResponse = {
|
||||
@@ -36,7 +38,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/idp/oidc",
|
||||
description: "Create an OIDC IdP.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
body: {
|
||||
content: {
|
||||
@@ -75,9 +77,22 @@ export async function createOidcIdp(
|
||||
emailPath,
|
||||
namePath,
|
||||
name,
|
||||
autoProvision
|
||||
autoProvision,
|
||||
tags,
|
||||
variant
|
||||
} = parsedBody.data;
|
||||
|
||||
if (
|
||||
process.env.IDENTITY_PROVIDER_MODE === "org"
|
||||
) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const key = config.getRawConfig().server.secret!;
|
||||
|
||||
const encryptedSecret = encrypt(clientSecret, key);
|
||||
@@ -90,7 +105,10 @@ export async function createOidcIdp(
|
||||
.values({
|
||||
name,
|
||||
autoProvision,
|
||||
type: "oidc"
|
||||
type: "oidc",
|
||||
tags,
|
||||
defaultOrgMapping: `'{{orgId}}'`,
|
||||
defaultRoleMapping: `'Member'`
|
||||
})
|
||||
.returning();
|
||||
|
||||
@@ -105,7 +123,8 @@ export async function createOidcIdp(
|
||||
scopes,
|
||||
identifierPath,
|
||||
emailPath,
|
||||
namePath
|
||||
namePath,
|
||||
variant
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ registry.registerPath({
|
||||
method: "delete",
|
||||
path: "/idp/{idpId}",
|
||||
description: "Delete IDP.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
|
||||
@@ -19,7 +19,7 @@ registry.registerPath({
|
||||
method: "delete",
|
||||
path: "/idp/{idpId}/org/{orgId}",
|
||||
description: "Create an OIDC IdP for an organization.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
|
||||
@@ -14,8 +14,8 @@ import jsonwebtoken from "jsonwebtoken";
|
||||
import config from "@server/lib/config";
|
||||
import { decrypt } from "@server/lib/crypto";
|
||||
import { build } from "@server/build";
|
||||
import { getOrgTierData } from "#dynamic/lib/billing";
|
||||
import { TierId } from "@server/lib/billing/tiers";
|
||||
import { isSubscribed } from "#dynamic/lib/isSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
@@ -113,8 +113,10 @@ export async function generateOidcUrl(
|
||||
}
|
||||
|
||||
if (build === "saas") {
|
||||
const { tier } = await getOrgTierData(orgId);
|
||||
const subscribed = tier === TierId.STANDARD;
|
||||
const subscribed = await isSubscribed(
|
||||
orgId,
|
||||
tierMatrix.orgOidc
|
||||
);
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
|
||||
@@ -34,7 +34,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/idp/{idpId}",
|
||||
description: "Get an IDP by its IDP ID.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
|
||||
@@ -48,7 +48,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/idp/{idpId}/org",
|
||||
description: "List all org policies on an IDP.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
query: querySchema
|
||||
|
||||
@@ -33,7 +33,8 @@ async function query(limit: number, offset: number) {
|
||||
type: idp.type,
|
||||
variant: idpOidcConfig.variant,
|
||||
orgCount: sql<number>`count(${idpOrg.orgId})`,
|
||||
autoProvision: idp.autoProvision
|
||||
autoProvision: idp.autoProvision,
|
||||
tags: idp.tags
|
||||
})
|
||||
.from(idp)
|
||||
.leftJoin(idpOrg, sql`${idp.idpId} = ${idpOrg.idpId}`)
|
||||
@@ -57,7 +58,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/idp",
|
||||
description: "List all IDP in the system.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
query: querySchema
|
||||
},
|
||||
|
||||
@@ -26,7 +26,7 @@ registry.registerPath({
|
||||
method: "post",
|
||||
path: "/idp/{idpId}/org/{orgId}",
|
||||
description: "Update an IDP org policy.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
@@ -69,6 +69,15 @@ export async function updateIdpOrgPolicy(
|
||||
const { idpId, orgId } = parsedParams.data;
|
||||
const { roleMapping, orgMapping } = parsedBody.data;
|
||||
|
||||
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if IDP and policy exist
|
||||
const [existing] = await db
|
||||
.select()
|
||||
|
||||
@@ -30,7 +30,9 @@ const bodySchema = z.strictObject({
|
||||
scopes: z.string().optional(),
|
||||
autoProvision: z.boolean().optional(),
|
||||
defaultRoleMapping: z.string().optional(),
|
||||
defaultOrgMapping: z.string().optional()
|
||||
defaultOrgMapping: z.string().optional(),
|
||||
tags: z.string().optional(),
|
||||
variant: z.enum(["oidc", "google", "azure"]).optional()
|
||||
});
|
||||
|
||||
export type UpdateIdpResponse = {
|
||||
@@ -41,7 +43,7 @@ registry.registerPath({
|
||||
method: "post",
|
||||
path: "/idp/{idpId}/oidc",
|
||||
description: "Update an OIDC IdP.",
|
||||
tags: [OpenAPITags.Idp],
|
||||
tags: [OpenAPITags.GlobalIdp],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
@@ -94,9 +96,20 @@ export async function updateOidcIdp(
|
||||
name,
|
||||
autoProvision,
|
||||
defaultRoleMapping,
|
||||
defaultOrgMapping
|
||||
defaultOrgMapping,
|
||||
tags,
|
||||
variant
|
||||
} = parsedBody.data;
|
||||
|
||||
if (process.env.IDENTITY_PROVIDER_MODE === "org") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Global IdP creation is not allowed in the current identity provider mode. Set app.identity_provider_mode to 'global' in the private configuration to enable this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if IDP exists and is of type OIDC
|
||||
const [existingIdp] = await db
|
||||
.select()
|
||||
@@ -127,7 +140,8 @@ export async function updateOidcIdp(
|
||||
name,
|
||||
autoProvision,
|
||||
defaultRoleMapping,
|
||||
defaultOrgMapping
|
||||
defaultOrgMapping,
|
||||
tags
|
||||
};
|
||||
|
||||
// only update if at least one key is not undefined
|
||||
@@ -147,7 +161,8 @@ export async function updateOidcIdp(
|
||||
scopes,
|
||||
identifierPath,
|
||||
emailPath,
|
||||
namePath
|
||||
namePath,
|
||||
variant
|
||||
};
|
||||
|
||||
keysToUpdate = Object.keys(configData).filter(
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
orgs,
|
||||
Role,
|
||||
roles,
|
||||
userOrgRoles,
|
||||
userOrgs,
|
||||
users
|
||||
} from "@server/db";
|
||||
@@ -34,6 +35,14 @@ import { FeatureId } from "@server/lib/billing";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { build } from "@server/build";
|
||||
import { calculateUserClientsForOrgs } from "@server/lib/calculateUserClientsForOrgs";
|
||||
import { isSubscribed } from "#dynamic/lib/isSubscribed";
|
||||
import { isLicensedOrSubscribed } from "#dynamic/lib/isLicencedOrSubscribed";
|
||||
import { tierMatrix } from "@server/lib/billing/tierMatrix";
|
||||
import {
|
||||
assignUserToOrg,
|
||||
removeUserFromOrg
|
||||
} from "@server/lib/userOrg";
|
||||
import { unwrapRoleMapping } from "@app/lib/idpRoleMapping";
|
||||
|
||||
const ensureTrailingSlash = (url: string): string => {
|
||||
return url;
|
||||
@@ -192,11 +201,71 @@ export async function validateOidcCallback(
|
||||
state
|
||||
});
|
||||
|
||||
const tokens = await client.validateAuthorizationCode(
|
||||
ensureTrailingSlash(existingIdp.idpOidcConfig.tokenUrl),
|
||||
code,
|
||||
codeVerifier
|
||||
);
|
||||
let tokens: arctic.OAuth2Tokens;
|
||||
try {
|
||||
tokens = await client.validateAuthorizationCode(
|
||||
ensureTrailingSlash(existingIdp.idpOidcConfig.tokenUrl),
|
||||
code,
|
||||
codeVerifier
|
||||
);
|
||||
} catch (err: unknown) {
|
||||
if (err instanceof arctic.OAuth2RequestError) {
|
||||
logger.warn("OIDC provider rejected the authorization code", {
|
||||
error: err.code,
|
||||
description: err.description,
|
||||
uri: err.uri,
|
||||
state: err.state
|
||||
});
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
err.description ||
|
||||
`OIDC provider rejected the request (${err.code})`
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (err instanceof arctic.UnexpectedResponseError) {
|
||||
logger.error(
|
||||
"OIDC provider returned an unexpected response during token exchange",
|
||||
{ status: err.status }
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_GATEWAY,
|
||||
"Received an unexpected response from the identity provider while exchanging the authorization code."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (err instanceof arctic.UnexpectedErrorResponseBodyError) {
|
||||
logger.error(
|
||||
"OIDC provider returned an unexpected error payload during token exchange",
|
||||
{ status: err.status, data: err.data }
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_GATEWAY,
|
||||
"Identity provider returned an unexpected error payload while exchanging the authorization code."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (err instanceof arctic.ArcticFetchError) {
|
||||
logger.error(
|
||||
"Failed to reach OIDC provider while exchanging authorization code",
|
||||
{ error: err.message }
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_GATEWAY,
|
||||
"Unable to reach the identity provider while exchanging the authorization code. Please try again."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
|
||||
const idToken = tokens.idToken();
|
||||
logger.debug("ID token", { idToken });
|
||||
@@ -266,6 +335,33 @@ export async function validateOidcCallback(
|
||||
.where(eq(idpOrg.idpId, existingIdp.idp.idpId))
|
||||
.innerJoin(orgs, eq(orgs.orgId, idpOrg.orgId));
|
||||
allOrgs = idpOrgs.map((o) => o.orgs);
|
||||
|
||||
// TODO: when there are multiple orgs we need to do this better!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!1
|
||||
if (allOrgs.length > 1) {
|
||||
// for some reason there is more than one org
|
||||
logger.error(
|
||||
"More than one organization linked to this IdP. This should not happen with auto-provisioning enabled."
|
||||
);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Multiple organizations linked to this IdP. Please contact support."
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const subscribed = await isSubscribed(
|
||||
allOrgs[0].orgId,
|
||||
tierMatrix.autoProvisioning
|
||||
);
|
||||
if (!subscribed) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"This organization's current plan does not support this feature."
|
||||
)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
allOrgs = await db.select().from(orgs);
|
||||
}
|
||||
@@ -273,7 +369,7 @@ export async function validateOidcCallback(
|
||||
const defaultRoleMapping = existingIdp.idp.defaultRoleMapping;
|
||||
const defaultOrgMapping = existingIdp.idp.defaultOrgMapping;
|
||||
|
||||
const userOrgInfo: { orgId: string; roleId: number }[] = [];
|
||||
const userOrgInfo: { orgId: string; roleIds: number[] }[] = [];
|
||||
for (const org of allOrgs) {
|
||||
const [idpOrgRes] = await db
|
||||
.select()
|
||||
@@ -285,8 +381,6 @@ export async function validateOidcCallback(
|
||||
)
|
||||
);
|
||||
|
||||
let roleId: number | undefined = undefined;
|
||||
|
||||
const orgMapping = idpOrgRes?.orgMapping || defaultOrgMapping;
|
||||
const hydratedOrgMapping = hydrateOrgMapping(
|
||||
orgMapping,
|
||||
@@ -311,42 +405,60 @@ export async function validateOidcCallback(
|
||||
idpOrgRes?.roleMapping || defaultRoleMapping;
|
||||
if (roleMapping) {
|
||||
logger.debug("Role Mapping", { roleMapping });
|
||||
const roleName = jmespath.search(claims, roleMapping);
|
||||
const roleMappingJmes = unwrapRoleMapping(
|
||||
roleMapping
|
||||
).evaluationExpression;
|
||||
const roleMappingResult = jmespath.search(
|
||||
claims,
|
||||
roleMappingJmes
|
||||
);
|
||||
const roleNames = normalizeRoleMappingResult(
|
||||
roleMappingResult
|
||||
);
|
||||
|
||||
if (!roleName) {
|
||||
logger.error("Role name not found in the ID token", {
|
||||
roleName
|
||||
const supportsMultiRole = await isLicensedOrSubscribed(
|
||||
org.orgId,
|
||||
tierMatrix.fullRbac
|
||||
);
|
||||
const effectiveRoleNames = supportsMultiRole
|
||||
? roleNames
|
||||
: roleNames.slice(0, 1);
|
||||
|
||||
if (!effectiveRoleNames.length) {
|
||||
logger.error("Role mapping returned no valid roles", {
|
||||
roleMappingResult
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
const [roleRes] = await db
|
||||
const roleRes = await db
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(
|
||||
and(
|
||||
eq(roles.orgId, org.orgId),
|
||||
eq(roles.name, roleName)
|
||||
inArray(roles.name, effectiveRoleNames)
|
||||
)
|
||||
);
|
||||
|
||||
if (!roleRes) {
|
||||
logger.error("Role not found", {
|
||||
if (!roleRes.length) {
|
||||
logger.error("No mapped roles found in organization", {
|
||||
orgId: org.orgId,
|
||||
roleName
|
||||
roleNames: effectiveRoleNames
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
roleId = roleRes.roleId;
|
||||
const roleIds = [...new Set(roleRes.map((r) => r.roleId))];
|
||||
|
||||
userOrgInfo.push({
|
||||
orgId: org.orgId,
|
||||
roleId
|
||||
roleIds
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// These are the orgs that the user should be provisioned into based on the IdP mappings and the token claims
|
||||
logger.debug("User org info", { userOrgInfo });
|
||||
|
||||
let existingUserId = existingUser?.userId;
|
||||
@@ -365,15 +477,32 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (!existingUserOrgs.length) {
|
||||
// delete all auto -provisioned user orgs
|
||||
await db
|
||||
.delete(userOrgs)
|
||||
// delete all auto-provisioned user orgs
|
||||
const autoProvisionedUserOrgs = await db
|
||||
.select()
|
||||
.from(userOrgs)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, existingUser.userId),
|
||||
eq(userOrgs.autoProvisioned, true)
|
||||
)
|
||||
);
|
||||
const orgIdsToRemove = autoProvisionedUserOrgs.map(
|
||||
(uo) => uo.orgId
|
||||
);
|
||||
if (orgIdsToRemove.length > 0) {
|
||||
const orgsToRemove = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(inArray(orgs.orgId, orgIdsToRemove));
|
||||
for (const org of orgsToRemove) {
|
||||
await removeUserFromOrg(
|
||||
org,
|
||||
existingUser.userId,
|
||||
db
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await calculateUserClientsForOrgs(existingUser.userId);
|
||||
|
||||
@@ -395,7 +524,7 @@ export async function validateOidcCallback(
|
||||
}
|
||||
}
|
||||
|
||||
const orgUserCounts: { orgId: string; userCount: number }[] = [];
|
||||
const orgUserCounts: { orgId: string; userCount: number }[] = [];
|
||||
|
||||
// sync the user with the orgs and roles
|
||||
await db.transaction(async (trx) => {
|
||||
@@ -449,43 +578,38 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (orgsToDelete.length > 0) {
|
||||
await trx.delete(userOrgs).where(
|
||||
and(
|
||||
eq(userOrgs.userId, userId!),
|
||||
inArray(
|
||||
userOrgs.orgId,
|
||||
orgsToDelete.map((org) => org.orgId)
|
||||
)
|
||||
)
|
||||
);
|
||||
const orgIdsToRemove = orgsToDelete.map((org) => org.orgId);
|
||||
const fullOrgsToRemove = await trx
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(inArray(orgs.orgId, orgIdsToRemove));
|
||||
for (const org of fullOrgsToRemove) {
|
||||
await removeUserFromOrg(org, userId!, trx);
|
||||
}
|
||||
}
|
||||
|
||||
// Update roles for existing auto-provisioned orgs where the role has changed
|
||||
const orgsToUpdate = autoProvisionedOrgs.filter(
|
||||
(currentOrg) => {
|
||||
const newOrg = userOrgInfo.find(
|
||||
(newOrg) => newOrg.orgId === currentOrg.orgId
|
||||
);
|
||||
return newOrg && newOrg.roleId !== currentOrg.roleId;
|
||||
}
|
||||
);
|
||||
// Sync roles 1:1 with IdP policy for existing auto-provisioned orgs
|
||||
for (const currentOrg of autoProvisionedOrgs) {
|
||||
const newRole = userOrgInfo.find(
|
||||
(newOrg) => newOrg.orgId === currentOrg.orgId
|
||||
);
|
||||
if (!newRole) continue;
|
||||
|
||||
if (orgsToUpdate.length > 0) {
|
||||
for (const org of orgsToUpdate) {
|
||||
const newRole = userOrgInfo.find(
|
||||
(newOrg) => newOrg.orgId === org.orgId
|
||||
await trx
|
||||
.delete(userOrgRoles)
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgRoles.userId, userId!),
|
||||
eq(userOrgRoles.orgId, currentOrg.orgId)
|
||||
)
|
||||
);
|
||||
if (newRole) {
|
||||
await trx
|
||||
.update(userOrgs)
|
||||
.set({ roleId: newRole.roleId })
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.userId, userId!),
|
||||
eq(userOrgs.orgId, org.orgId)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
for (const roleId of newRole.roleIds) {
|
||||
await trx.insert(userOrgRoles).values({
|
||||
userId: userId!,
|
||||
orgId: currentOrg.orgId,
|
||||
roleId
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,15 +622,28 @@ export async function validateOidcCallback(
|
||||
);
|
||||
|
||||
if (orgsToAdd.length > 0) {
|
||||
await trx.insert(userOrgs).values(
|
||||
orgsToAdd.map((org) => ({
|
||||
userId: userId!,
|
||||
orgId: org.orgId,
|
||||
roleId: org.roleId,
|
||||
autoProvisioned: true,
|
||||
dateCreated: new Date().toISOString()
|
||||
}))
|
||||
);
|
||||
for (const org of orgsToAdd) {
|
||||
if (org.roleIds.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const [fullOrg] = await trx
|
||||
.select()
|
||||
.from(orgs)
|
||||
.where(eq(orgs.orgId, org.orgId));
|
||||
if (fullOrg) {
|
||||
await assignUserToOrg(
|
||||
fullOrg,
|
||||
{
|
||||
orgId: org.orgId,
|
||||
userId: userId!,
|
||||
autoProvisioned: true,
|
||||
},
|
||||
org.roleIds,
|
||||
trx
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Loop through all the orgs and get the total number of users from the userOrgs table
|
||||
@@ -527,7 +664,7 @@ export async function validateOidcCallback(
|
||||
});
|
||||
|
||||
for (const orgCount of orgUserCounts) {
|
||||
await usageService.updateDaily(
|
||||
await usageService.updateCount(
|
||||
orgCount.orgId,
|
||||
FeatureId.USERS,
|
||||
orgCount.userCount
|
||||
@@ -545,9 +682,18 @@ export async function validateOidcCallback(
|
||||
|
||||
res.appendHeader("Set-Cookie", cookie);
|
||||
|
||||
let finalRedirectUrl = postAuthRedirectUrl;
|
||||
if (loginPageId) {
|
||||
finalRedirectUrl = `/auth/org/?redirect=${encodeURIComponent(
|
||||
postAuthRedirectUrl
|
||||
)}`;
|
||||
}
|
||||
|
||||
logger.debug("Final redirect URL", { finalRedirectUrl });
|
||||
|
||||
return response<ValidateOidcUrlCallbackResponse>(res, {
|
||||
data: {
|
||||
redirectUrl: postAuthRedirectUrl
|
||||
redirectUrl: finalRedirectUrl
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
@@ -620,3 +766,25 @@ function hydrateOrgMapping(
|
||||
}
|
||||
return orgMapping.split("{{orgId}}").join(orgId);
|
||||
}
|
||||
|
||||
function normalizeRoleMappingResult(
|
||||
result: unknown
|
||||
): string[] {
|
||||
if (typeof result === "string") {
|
||||
const role = result.trim();
|
||||
return role ? [role] : [];
|
||||
}
|
||||
|
||||
if (Array.isArray(result)) {
|
||||
return [
|
||||
...new Set(
|
||||
result
|
||||
.filter((value): value is string => typeof value === "string")
|
||||
.map((value) => value.trim())
|
||||
.filter(Boolean)
|
||||
)
|
||||
];
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
verifyApiKey,
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction,
|
||||
verifyApiKeyCanSetUserOrgRoles,
|
||||
verifyApiKeySiteAccess,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeyTargetAccess,
|
||||
@@ -26,7 +27,9 @@ import {
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyClientAccess,
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyApiKeyDomainAccess
|
||||
} from "@server/middlewares";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { Router } from "express";
|
||||
@@ -74,6 +77,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/org/:orgId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateOrg),
|
||||
logActionAudit(ActionsEnum.updateOrg),
|
||||
org.updateOrg
|
||||
@@ -90,6 +94,7 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/org/:orgId/site",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createSite),
|
||||
logActionAudit(ActionsEnum.createSite),
|
||||
site.createSite
|
||||
@@ -126,10 +131,18 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/site/:siteId",
|
||||
verifyApiKeySiteAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateSite),
|
||||
logActionAudit(ActionsEnum.updateSite),
|
||||
site.updateSite
|
||||
);
|
||||
authenticated.post(
|
||||
"/org/:orgId/reset-bandwidth",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.resetSiteBandwidth),
|
||||
logActionAudit(ActionsEnum.resetSiteBandwidth),
|
||||
org.resetOrgBandwidth
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/site/:siteId",
|
||||
@@ -146,9 +159,9 @@ authenticated.get(
|
||||
);
|
||||
// Site Resource endpoints
|
||||
authenticated.put(
|
||||
"/org/:orgId/site/:siteId/resource",
|
||||
"/org/:orgId/site-resource",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeySiteAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createSiteResource),
|
||||
logActionAudit(ActionsEnum.createSiteResource),
|
||||
siteResource.createSiteResource
|
||||
@@ -170,28 +183,23 @@ authenticated.get(
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/site/:siteId/resource/:siteResourceId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeySiteAccess,
|
||||
"/site-resource/:siteResourceId",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getSiteResource),
|
||||
siteResource.getSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/site/:siteId/resource/:siteResourceId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeySiteAccess,
|
||||
"/site-resource/:siteResourceId",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateSiteResource),
|
||||
logActionAudit(ActionsEnum.updateSiteResource),
|
||||
siteResource.updateSiteResource
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/site/:siteId/resource/:siteResourceId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeySiteAccess,
|
||||
"/site-resource/:siteResourceId",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteSiteResource),
|
||||
logActionAudit(ActionsEnum.deleteSiteResource),
|
||||
@@ -223,6 +231,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.setSiteResourceRoles
|
||||
@@ -232,6 +241,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/users",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceUsers
|
||||
@@ -241,6 +251,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles/add",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.addRoleToSiteResource
|
||||
@@ -250,6 +261,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/roles/remove",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
siteResource.removeRoleFromSiteResource
|
||||
@@ -259,6 +271,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/users/add",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.addUserToSiteResource
|
||||
@@ -268,6 +281,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/users/remove",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.removeUserFromSiteResource
|
||||
@@ -277,6 +291,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.setSiteResourceClients
|
||||
@@ -286,6 +301,7 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/add",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.addClientToSiteResource
|
||||
@@ -295,14 +311,24 @@ authenticated.post(
|
||||
"/site-resource/:siteResourceId/clients/remove",
|
||||
verifyApiKeySiteResourceAccess,
|
||||
verifyApiKeySetResourceClients,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.removeClientFromSiteResource
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/site-resources",
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
siteResource.batchAddClientToSiteResources
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/resource",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createResource),
|
||||
logActionAudit(ActionsEnum.createResource),
|
||||
resource.createResource
|
||||
@@ -311,6 +337,7 @@ authenticated.put(
|
||||
authenticated.put(
|
||||
"/org/:orgId/site/:siteId/resource",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createResource),
|
||||
logActionAudit(ActionsEnum.createResource),
|
||||
resource.createResource
|
||||
@@ -337,6 +364,56 @@ authenticated.get(
|
||||
domain.listDomains
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/domain/:domainId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyDomainAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getDomain),
|
||||
domain.getDomain
|
||||
);
|
||||
|
||||
authenticated.put(
|
||||
"/org/:orgId/domain",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.createOrgDomain),
|
||||
logActionAudit(ActionsEnum.createOrgDomain),
|
||||
domain.createOrgDomain
|
||||
);
|
||||
|
||||
authenticated.patch(
|
||||
"/org/:orgId/domain/:domainId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyDomainAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateOrgDomain),
|
||||
domain.updateOrgDomain
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/domain/:domainId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyDomainAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.deleteOrgDomain),
|
||||
logActionAudit(ActionsEnum.deleteOrgDomain),
|
||||
domain.deleteAccountDomain
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/domain/:domainId/dns-records",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyDomainAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getDNSRecords),
|
||||
domain.getDNSRecords
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/org/:orgId/domain/:domainId/restart",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyDomainAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.restartOrgDomain),
|
||||
logActionAudit(ActionsEnum.restartOrgDomain),
|
||||
domain.restartOrgDomain
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/invitations",
|
||||
verifyApiKeyOrgAccess,
|
||||
@@ -347,11 +424,20 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/org/:orgId/create-invite",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.inviteUser),
|
||||
logActionAudit(ActionsEnum.inviteUser),
|
||||
user.inviteUser
|
||||
);
|
||||
|
||||
authenticated.delete(
|
||||
"/org/:orgId/invitations/:inviteId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.removeInvitation),
|
||||
logActionAudit(ActionsEnum.removeInvitation),
|
||||
user.removeInvitation
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/resource/:resourceId/roles",
|
||||
verifyApiKeyResourceAccess,
|
||||
@@ -376,6 +462,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/resource/:resourceId",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateResource),
|
||||
logActionAudit(ActionsEnum.updateResource),
|
||||
resource.updateResource
|
||||
@@ -392,6 +479,7 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/resource/:resourceId/target",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createTarget),
|
||||
logActionAudit(ActionsEnum.createTarget),
|
||||
target.createTarget
|
||||
@@ -407,6 +495,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/resource/:resourceId/rule",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createResourceRule),
|
||||
logActionAudit(ActionsEnum.createResourceRule),
|
||||
resource.createResourceRule
|
||||
@@ -422,6 +511,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/rule/:ruleId",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateResourceRule),
|
||||
logActionAudit(ActionsEnum.updateResourceRule),
|
||||
resource.updateResourceRule
|
||||
@@ -445,6 +535,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
"/target/:targetId",
|
||||
verifyApiKeyTargetAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateTarget),
|
||||
logActionAudit(ActionsEnum.updateTarget),
|
||||
target.updateTarget
|
||||
@@ -461,11 +552,21 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/org/:orgId/role",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createRole),
|
||||
logActionAudit(ActionsEnum.createRole),
|
||||
role.createRole
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/role/:roleId",
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateRole),
|
||||
logActionAudit(ActionsEnum.updateRole),
|
||||
role.updateRole
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/roles",
|
||||
verifyApiKeyOrgAccess,
|
||||
@@ -492,15 +593,17 @@ authenticated.post(
|
||||
"/role/:roleId/add/:userId",
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyApiKeyUserAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.addUserRole),
|
||||
logActionAudit(ActionsEnum.addUserRole),
|
||||
user.addUserRole
|
||||
user.addUserRoleLegacy
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/resource/:resourceId/roles",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
resource.setResourceRoles
|
||||
@@ -510,6 +613,7 @@ authenticated.post(
|
||||
"/resource/:resourceId/users",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
resource.setResourceUsers
|
||||
@@ -519,6 +623,7 @@ authenticated.post(
|
||||
"/resource/:resourceId/roles/add",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
resource.addRoleToResource
|
||||
@@ -528,6 +633,7 @@ authenticated.post(
|
||||
"/resource/:resourceId/roles/remove",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeyRoleAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceRoles),
|
||||
logActionAudit(ActionsEnum.setResourceRoles),
|
||||
resource.removeRoleFromResource
|
||||
@@ -537,6 +643,7 @@ authenticated.post(
|
||||
"/resource/:resourceId/users/add",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
resource.addUserToResource
|
||||
@@ -546,6 +653,7 @@ authenticated.post(
|
||||
"/resource/:resourceId/users/remove",
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyApiKeySetResourceUsers,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceUsers),
|
||||
logActionAudit(ActionsEnum.setResourceUsers),
|
||||
resource.removeUserFromResource
|
||||
@@ -554,6 +662,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/password`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourcePassword),
|
||||
logActionAudit(ActionsEnum.setResourcePassword),
|
||||
resource.setResourcePassword
|
||||
@@ -562,6 +671,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/pincode`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourcePincode),
|
||||
logActionAudit(ActionsEnum.setResourcePincode),
|
||||
resource.setResourcePincode
|
||||
@@ -570,6 +680,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/header-auth`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceHeaderAuth),
|
||||
logActionAudit(ActionsEnum.setResourceHeaderAuth),
|
||||
resource.setResourceHeaderAuth
|
||||
@@ -578,6 +689,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/whitelist`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist),
|
||||
logActionAudit(ActionsEnum.setResourceWhitelist),
|
||||
resource.setResourceWhitelist
|
||||
@@ -586,6 +698,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/whitelist/add`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist),
|
||||
resource.addEmailToResourceWhitelist
|
||||
);
|
||||
@@ -593,6 +706,7 @@ authenticated.post(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/whitelist/remove`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setResourceWhitelist),
|
||||
resource.removeEmailFromResourceWhitelist
|
||||
);
|
||||
@@ -607,6 +721,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
`/resource/:resourceId/access-token`,
|
||||
verifyApiKeyResourceAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.generateAccessToken),
|
||||
logActionAudit(ActionsEnum.generateAccessToken),
|
||||
accessToken.generateAccessToken
|
||||
@@ -641,9 +756,17 @@ authenticated.get(
|
||||
user.getOrgUser
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/user-by-username",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getOrgUser),
|
||||
user.getOrgUserByUsername
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/user/:userId/2fa",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateUser),
|
||||
logActionAudit(ActionsEnum.updateUser),
|
||||
user.updateUser2FA
|
||||
@@ -666,6 +789,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/user",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createOrgUser),
|
||||
logActionAudit(ActionsEnum.createOrgUser),
|
||||
user.createOrgUser
|
||||
@@ -675,6 +799,7 @@ authenticated.post(
|
||||
"/org/:orgId/user/:userId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyUserAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateOrgUser),
|
||||
logActionAudit(ActionsEnum.updateOrgUser),
|
||||
user.updateOrgUser
|
||||
@@ -705,6 +830,7 @@ authenticated.get(
|
||||
authenticated.post(
|
||||
`/org/:orgId/api-key/:apiKeyId/actions`,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.setApiKeyActions),
|
||||
logActionAudit(ActionsEnum.setApiKeyActions),
|
||||
apiKeys.setApiKeyActions
|
||||
@@ -720,6 +846,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
`/org/:orgId/api-key`,
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createApiKey),
|
||||
logActionAudit(ActionsEnum.createApiKey),
|
||||
apiKeys.createOrgApiKey
|
||||
@@ -736,6 +863,7 @@ authenticated.delete(
|
||||
authenticated.put(
|
||||
"/idp/oidc",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createIdp),
|
||||
logActionAudit(ActionsEnum.createIdp),
|
||||
idp.createOidcIdp
|
||||
@@ -744,15 +872,17 @@ authenticated.put(
|
||||
authenticated.post(
|
||||
"/idp/:idpId/oidc",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateIdp),
|
||||
logActionAudit(ActionsEnum.updateIdp),
|
||||
idp.updateOidcIdp
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/idp",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyApiKeyHasAction(ActionsEnum.listIdps),
|
||||
"/idp", // no guards on this because anyone can list idps for login purposes
|
||||
// we do the same for the external api
|
||||
// verifyApiKeyIsRoot,
|
||||
// verifyApiKeyHasAction(ActionsEnum.listIdps),
|
||||
idp.listIdps
|
||||
);
|
||||
|
||||
@@ -766,6 +896,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/idp/:idpId/org/:orgId",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createIdpOrg),
|
||||
logActionAudit(ActionsEnum.createIdpOrg),
|
||||
idp.createIdpOrgPolicy
|
||||
@@ -774,6 +905,7 @@ authenticated.put(
|
||||
authenticated.post(
|
||||
"/idp/:idpId/org/:orgId",
|
||||
verifyApiKeyIsRoot,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateIdpOrg),
|
||||
logActionAudit(ActionsEnum.updateIdpOrg),
|
||||
idp.updateIdpOrgPolicy
|
||||
@@ -808,6 +940,13 @@ authenticated.get(
|
||||
client.listClients
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/user-devices",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listClients),
|
||||
client.listUserDevices
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/client/:clientId",
|
||||
verifyApiKeyClientAccess,
|
||||
@@ -818,6 +957,7 @@ authenticated.get(
|
||||
authenticated.put(
|
||||
"/org/:orgId/client",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.createClient),
|
||||
logActionAudit(ActionsEnum.createClient),
|
||||
client.createClient
|
||||
@@ -841,9 +981,46 @@ authenticated.delete(
|
||||
client.deleteClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/archive",
|
||||
verifyApiKeyClientAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.archiveClient),
|
||||
logActionAudit(ActionsEnum.archiveClient),
|
||||
client.archiveClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/unarchive",
|
||||
verifyApiKeyClientAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.unarchiveClient),
|
||||
logActionAudit(ActionsEnum.unarchiveClient),
|
||||
client.unarchiveClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/block",
|
||||
verifyApiKeyClientAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.blockClient),
|
||||
logActionAudit(ActionsEnum.blockClient),
|
||||
client.blockClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId/unblock",
|
||||
verifyApiKeyClientAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.unblockClient),
|
||||
logActionAudit(ActionsEnum.unblockClient),
|
||||
client.unblockClient
|
||||
);
|
||||
|
||||
authenticated.post(
|
||||
"/client/:clientId",
|
||||
verifyApiKeyClientAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.updateClient),
|
||||
logActionAudit(ActionsEnum.updateClient),
|
||||
client.updateClient
|
||||
@@ -852,11 +1029,26 @@ authenticated.post(
|
||||
authenticated.put(
|
||||
"/org/:orgId/blueprint",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyApiKeyHasAction(ActionsEnum.applyBlueprint),
|
||||
logActionAudit(ActionsEnum.applyBlueprint),
|
||||
blueprints.applyJSONBlueprint
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/blueprint/:blueprintId",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.getBlueprint),
|
||||
blueprints.getBlueprint
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/blueprints",
|
||||
verifyApiKeyOrgAccess,
|
||||
verifyApiKeyHasAction(ActionsEnum.listBlueprints),
|
||||
blueprints.listBlueprints
|
||||
);
|
||||
|
||||
authenticated.get(
|
||||
"/org/:orgId/logs/request",
|
||||
verifyApiKeyOrgAccess,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { LoginPage } from "@server/db";
|
||||
import type { LoginPage, LoginPageBranding } from "@server/db";
|
||||
|
||||
export type CreateLoginPageResponse = LoginPage;
|
||||
|
||||
@@ -9,3 +9,10 @@ export type GetLoginPageResponse = LoginPage;
|
||||
export type UpdateLoginPageResponse = LoginPage;
|
||||
|
||||
export type LoadLoginPageResponse = LoginPage & { orgId: string };
|
||||
|
||||
export type LoadLoginPageBrandingResponse = LoginPageBranding & {
|
||||
orgId: string;
|
||||
orgName: string;
|
||||
};
|
||||
|
||||
export type GetLoginPageBrandingResponse = LoginPageBranding;
|
||||
|
||||
299
server/routers/newt/buildConfiguration.ts
Normal file
299
server/routers/newt/buildConfiguration.ts
Normal file
@@ -0,0 +1,299 @@
|
||||
import {
|
||||
clients,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
clientSitesAssociationsCache,
|
||||
db,
|
||||
ExitNode,
|
||||
resources,
|
||||
Site,
|
||||
siteResources,
|
||||
targetHealthCheck,
|
||||
targets
|
||||
} from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import config from "@server/lib/config";
|
||||
import {
|
||||
formatEndpoint,
|
||||
generateSubnetProxyTargetV2,
|
||||
SubnetProxyTargetV2
|
||||
} from "@server/lib/ip";
|
||||
|
||||
export async function buildClientConfigurationForNewtClient(
|
||||
site: Site,
|
||||
exitNode?: ExitNode
|
||||
) {
|
||||
const siteId = site.siteId;
|
||||
|
||||
// Get all clients connected to this site
|
||||
const clientsRes = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(clients.clientId, clientSitesAssociationsCache.clientId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.siteId, siteId));
|
||||
|
||||
let peers: Array<{
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
}> = [];
|
||||
|
||||
if (site.publicKey && site.endpoint && exitNode) {
|
||||
// Prepare peers data for the response
|
||||
peers = await Promise.all(
|
||||
clientsRes
|
||||
.filter((client) => {
|
||||
if (!client.clients.pubKey) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no public key, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
if (!client.clients.subnet) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no subnet, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map(async (client) => {
|
||||
// Add or update this peer on the olm if it is connected
|
||||
|
||||
// const allSiteResources = await db // only get the site resources that this client has access to
|
||||
// .select()
|
||||
// .from(siteResources)
|
||||
// .innerJoin(
|
||||
// clientSiteResourcesAssociationsCache,
|
||||
// eq(
|
||||
// siteResources.siteResourceId,
|
||||
// clientSiteResourcesAssociationsCache.siteResourceId
|
||||
// )
|
||||
// )
|
||||
// .where(
|
||||
// and(
|
||||
// eq(siteResources.siteId, site.siteId),
|
||||
// eq(
|
||||
// clientSiteResourcesAssociationsCache.clientId,
|
||||
// client.clients.clientId
|
||||
// )
|
||||
// )
|
||||
// );
|
||||
|
||||
if (!client.clientSitesAssociationsCache.isJitMode) { // if we are adding sites through jit then dont add the site to the olm
|
||||
// update the peer info on the olm
|
||||
// if the peer has not been added yet this will be a no-op
|
||||
await updatePeer(client.clients.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: site.endpoint!,
|
||||
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
|
||||
publicKey: site.publicKey!,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort
|
||||
// remoteSubnets: generateRemoteSubnets(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// ),
|
||||
// aliases: generateAliasConfig(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// )
|
||||
});
|
||||
|
||||
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
|
||||
// if it has already been added this will be a no-op
|
||||
await initPeerAddHandshake(
|
||||
// this will kick off the add peer process for the client
|
||||
client.clients.clientId,
|
||||
{
|
||||
siteId,
|
||||
exitNode: {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
publicKey: client.clients.pubKey!,
|
||||
allowedIps: [
|
||||
`${client.clients.subnet.split("/")[0]}/32`
|
||||
], // we want to only allow from that client
|
||||
endpoint: client.clientSitesAssociationsCache.isRelayed
|
||||
? ""
|
||||
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
|
||||
};
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Filter out any null values from peers that didn't have an olm
|
||||
const validPeers = peers.filter((peer) => peer !== null);
|
||||
|
||||
// Get all enabled site resources for this site
|
||||
const allSiteResources = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(eq(siteResources.siteId, siteId));
|
||||
|
||||
const targetsToSend: SubnetProxyTargetV2[] = [];
|
||||
|
||||
for (const resource of allSiteResources) {
|
||||
// Get clients associated with this specific resource
|
||||
const resourceClients = await db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
pubKey: clients.pubKey,
|
||||
subnet: clients.subnet
|
||||
})
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
clients.clientId,
|
||||
clientSiteResourcesAssociationsCache.clientId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||
resource.siteResourceId
|
||||
)
|
||||
);
|
||||
|
||||
const resourceTarget = generateSubnetProxyTargetV2(
|
||||
resource,
|
||||
resourceClients
|
||||
);
|
||||
|
||||
if (resourceTarget) {
|
||||
targetsToSend.push(resourceTarget);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
peers: validPeers,
|
||||
targets: targetsToSend
|
||||
};
|
||||
}
|
||||
|
||||
export async function buildTargetConfigurationForNewtClient(siteId: number) {
|
||||
// Get all enabled targets with their resource protocol information
|
||||
const allTargets = await db
|
||||
.select({
|
||||
resourceId: targets.resourceId,
|
||||
targetId: targets.targetId,
|
||||
ip: targets.ip,
|
||||
method: targets.method,
|
||||
port: targets.port,
|
||||
internalPort: targets.internalPort,
|
||||
enabled: targets.enabled,
|
||||
protocol: resources.protocol,
|
||||
hcEnabled: targetHealthCheck.hcEnabled,
|
||||
hcPath: targetHealthCheck.hcPath,
|
||||
hcScheme: targetHealthCheck.hcScheme,
|
||||
hcMode: targetHealthCheck.hcMode,
|
||||
hcHostname: targetHealthCheck.hcHostname,
|
||||
hcPort: targetHealthCheck.hcPort,
|
||||
hcInterval: targetHealthCheck.hcInterval,
|
||||
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
|
||||
hcTimeout: targetHealthCheck.hcTimeout,
|
||||
hcHeaders: targetHealthCheck.hcHeaders,
|
||||
hcMethod: targetHealthCheck.hcMethod,
|
||||
hcTlsServerName: targetHealthCheck.hcTlsServerName,
|
||||
hcStatus: targetHealthCheck.hcStatus
|
||||
})
|
||||
.from(targets)
|
||||
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
|
||||
.leftJoin(
|
||||
targetHealthCheck,
|
||||
eq(targets.targetId, targetHealthCheck.targetId)
|
||||
)
|
||||
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
|
||||
|
||||
const { tcpTargets, udpTargets } = allTargets.reduce(
|
||||
(acc, target) => {
|
||||
// Filter out invalid targets
|
||||
if (!target.internalPort || !target.ip || !target.port) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
// Format target into string (handles IPv6 bracketing)
|
||||
const formattedTarget = `${target.internalPort}:${formatEndpoint(target.ip, target.port)}`;
|
||||
|
||||
// Add to the appropriate protocol array
|
||||
if (target.protocol === "tcp") {
|
||||
acc.tcpTargets.push(formattedTarget);
|
||||
} else {
|
||||
acc.udpTargets.push(formattedTarget);
|
||||
}
|
||||
|
||||
return acc;
|
||||
},
|
||||
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
|
||||
);
|
||||
|
||||
const healthCheckTargets = allTargets.map((target) => {
|
||||
// make sure the stuff is defined
|
||||
if (
|
||||
!target.hcPath ||
|
||||
!target.hcHostname ||
|
||||
!target.hcPort ||
|
||||
!target.hcInterval ||
|
||||
!target.hcMethod
|
||||
) {
|
||||
// logger.debug(
|
||||
// `Skipping adding target health check ${target.targetId} due to missing health check fields`
|
||||
// );
|
||||
return null; // Skip targets with missing health check fields
|
||||
}
|
||||
|
||||
// parse headers
|
||||
const hcHeadersParse = target.hcHeaders
|
||||
? JSON.parse(target.hcHeaders)
|
||||
: null;
|
||||
const hcHeadersSend: { [key: string]: string } = {};
|
||||
if (hcHeadersParse) {
|
||||
hcHeadersParse.forEach(
|
||||
(header: { name: string; value: string }) => {
|
||||
hcHeadersSend[header.name] = header.value;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
id: target.targetId,
|
||||
hcEnabled: target.hcEnabled,
|
||||
hcPath: target.hcPath,
|
||||
hcScheme: target.hcScheme,
|
||||
hcMode: target.hcMode,
|
||||
hcHostname: target.hcHostname,
|
||||
hcPort: target.hcPort,
|
||||
hcInterval: target.hcInterval, // in seconds
|
||||
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
|
||||
hcTimeout: target.hcTimeout, // in seconds
|
||||
hcHeaders: hcHeadersSend,
|
||||
hcMethod: target.hcMethod,
|
||||
hcTlsServerName: target.hcTlsServerName,
|
||||
hcStatus: target.hcStatus
|
||||
};
|
||||
});
|
||||
|
||||
// Filter out any null values from health check targets
|
||||
const validHealthCheckTargets = healthCheckTargets.filter(
|
||||
(target) => target !== null
|
||||
);
|
||||
|
||||
return {
|
||||
validHealthCheckTargets,
|
||||
tcpTargets,
|
||||
udpTargets
|
||||
};
|
||||
}
|
||||
@@ -46,7 +46,7 @@ export async function createNewt(
|
||||
|
||||
const { newtId, secret } = parsedBody.data;
|
||||
|
||||
if (req.user && !req.userOrgRoleId) {
|
||||
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import { db } from "@server/db";
|
||||
import { db, newtSessions } from "@server/db";
|
||||
import { newts } from "@server/db";
|
||||
import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache";
|
||||
import { EXPIRES } from "@server/auth/sessions/newt";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import response from "@server/lib/response";
|
||||
import { eq } from "drizzle-orm";
|
||||
@@ -92,8 +94,19 @@ export async function getNewtToken(
|
||||
);
|
||||
}
|
||||
|
||||
const resToken = generateSessionToken();
|
||||
await createNewtSession(resToken, existingNewt.newtId);
|
||||
// Return a cached token if one exists to prevent thundering herd on
|
||||
// simultaneous restarts; falls back to creating a fresh session when
|
||||
// Redis is unavailable or the cache has expired.
|
||||
const resToken = await getOrCreateCachedToken(
|
||||
`newt:token_cache:${existingNewt.newtId}`,
|
||||
config.getRawConfig().server.secret!,
|
||||
Math.floor(EXPIRES / 1000),
|
||||
async () => {
|
||||
const token = generateSessionToken();
|
||||
await createNewtSession(token, existingNewt.newtId);
|
||||
return token;
|
||||
}
|
||||
);
|
||||
|
||||
return response<{ token: string; serverVersion: string }>(res, {
|
||||
data: {
|
||||
|
||||
13
server/routers/newt/handleConnectionLogMessage.ts
Normal file
13
server/routers/newt/handleConnectionLogMessage.ts
Normal file
@@ -0,0 +1,13 @@
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
|
||||
export async function flushConnectionLogToDb(): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
export async function cleanUpOldLogs(orgId: string, retentionDays: number) {
|
||||
return;
|
||||
}
|
||||
|
||||
export const handleConnectionLogMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
};
|
||||
@@ -2,23 +2,17 @@ import { z } from "zod";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import {
|
||||
db,
|
||||
ExitNode,
|
||||
exitNodes,
|
||||
siteResources,
|
||||
clientSiteResourcesAssociationsCache
|
||||
} from "@server/db";
|
||||
import { clients, clientSitesAssociationsCache, Newt, sites } from "@server/db";
|
||||
import { db, ExitNode, exitNodes, Newt, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { initPeerAddHandshake, updatePeer } from "../olm/peers";
|
||||
import { sendToExitNode } from "#dynamic/lib/exitNodes";
|
||||
import { generateSubnetProxyTargets, SubnetProxyTarget } from "@server/lib/ip";
|
||||
import config from "@server/lib/config";
|
||||
import { buildClientConfigurationForNewtClient } from "./buildConfiguration";
|
||||
import { convertTargetsIfNessicary } from "../client/targets";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
|
||||
const inputSchema = z.object({
|
||||
publicKey: z.string(),
|
||||
port: z.int().positive()
|
||||
port: z.int().positive(),
|
||||
chainId: z.string()
|
||||
});
|
||||
|
||||
type Input = z.infer<typeof inputSchema>;
|
||||
@@ -50,7 +44,7 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { publicKey, port } = message.data as Input;
|
||||
const { publicKey, port, chainId } = message.data as Input;
|
||||
const siteId = newt.siteId;
|
||||
|
||||
// Get the current site data
|
||||
@@ -113,11 +107,11 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
const payload = {
|
||||
oldDestination: {
|
||||
destinationIP: existingSite.subnet?.split("/")[0],
|
||||
destinationPort: existingSite.listenPort
|
||||
destinationPort: existingSite.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
|
||||
},
|
||||
newDestination: {
|
||||
destinationIP: site.subnet?.split("/")[0],
|
||||
destinationPort: site.listenPort
|
||||
destinationPort: site.listenPort || 1 // this satisfies gerbil for now but should be reevaluated
|
||||
}
|
||||
};
|
||||
|
||||
@@ -130,169 +124,26 @@ export const handleGetConfigMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Get all clients connected to this site
|
||||
const clientsRes = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(clients.clientId, clientSitesAssociationsCache.clientId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.siteId, siteId));
|
||||
const { peers, targets } = await buildClientConfigurationForNewtClient(
|
||||
site,
|
||||
exitNode
|
||||
);
|
||||
|
||||
let peers: Array<{
|
||||
publicKey: string;
|
||||
allowedIps: string[];
|
||||
endpoint?: string;
|
||||
}> = [];
|
||||
const targetsToSend = await convertTargetsIfNessicary(newt.newtId, targets);
|
||||
|
||||
if (site.publicKey && site.endpoint && exitNode) {
|
||||
// Prepare peers data for the response
|
||||
peers = await Promise.all(
|
||||
clientsRes
|
||||
.filter((client) => {
|
||||
if (!client.clients.pubKey) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no public key, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
if (!client.clients.subnet) {
|
||||
logger.warn(
|
||||
`Client ${client.clients.clientId} has no subnet, skipping`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.map(async (client) => {
|
||||
// Add or update this peer on the olm if it is connected
|
||||
|
||||
// const allSiteResources = await db // only get the site resources that this client has access to
|
||||
// .select()
|
||||
// .from(siteResources)
|
||||
// .innerJoin(
|
||||
// clientSiteResourcesAssociationsCache,
|
||||
// eq(
|
||||
// siteResources.siteResourceId,
|
||||
// clientSiteResourcesAssociationsCache.siteResourceId
|
||||
// )
|
||||
// )
|
||||
// .where(
|
||||
// and(
|
||||
// eq(siteResources.siteId, site.siteId),
|
||||
// eq(
|
||||
// clientSiteResourcesAssociationsCache.clientId,
|
||||
// client.clients.clientId
|
||||
// )
|
||||
// )
|
||||
// );
|
||||
|
||||
// update the peer info on the olm
|
||||
// if the peer has not been added yet this will be a no-op
|
||||
await updatePeer(client.clients.clientId, {
|
||||
siteId: site.siteId,
|
||||
endpoint: site.endpoint!,
|
||||
relayEndpoint: `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`,
|
||||
publicKey: site.publicKey!,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort
|
||||
// remoteSubnets: generateRemoteSubnets(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// ),
|
||||
// aliases: generateAliasConfig(
|
||||
// allSiteResources.map(
|
||||
// ({ siteResources }) => siteResources
|
||||
// )
|
||||
// )
|
||||
});
|
||||
|
||||
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
|
||||
// if it has already been added this will be a no-op
|
||||
await initPeerAddHandshake(
|
||||
// this will kick off the add peer process for the client
|
||||
client.clients.clientId,
|
||||
{
|
||||
siteId,
|
||||
exitNode: {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
return {
|
||||
publicKey: client.clients.pubKey!,
|
||||
allowedIps: [
|
||||
`${client.clients.subnet.split("/")[0]}/32`
|
||||
], // we want to only allow from that client
|
||||
endpoint: client.clientSitesAssociationsCache.isRelayed
|
||||
? ""
|
||||
: client.clientSitesAssociationsCache.endpoint! // if its relayed it should be localhost
|
||||
};
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Filter out any null values from peers that didn't have an olm
|
||||
const validPeers = peers.filter((peer) => peer !== null);
|
||||
|
||||
// Get all enabled site resources for this site
|
||||
const allSiteResources = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(eq(siteResources.siteId, siteId));
|
||||
|
||||
const targetsToSend: SubnetProxyTarget[] = [];
|
||||
|
||||
for (const resource of allSiteResources) {
|
||||
// Get clients associated with this specific resource
|
||||
const resourceClients = await db
|
||||
.select({
|
||||
clientId: clients.clientId,
|
||||
pubKey: clients.pubKey,
|
||||
subnet: clients.subnet
|
||||
})
|
||||
.from(clients)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
clients.clientId,
|
||||
clientSiteResourcesAssociationsCache.clientId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||
resource.siteResourceId
|
||||
)
|
||||
);
|
||||
|
||||
const resourceTargets = generateSubnetProxyTargets(
|
||||
resource,
|
||||
resourceClients
|
||||
);
|
||||
|
||||
targetsToSend.push(...resourceTargets);
|
||||
}
|
||||
|
||||
// Build the configuration response
|
||||
const configResponse = {
|
||||
ipAddress: site.address,
|
||||
peers: validPeers,
|
||||
targets: targetsToSend
|
||||
};
|
||||
|
||||
logger.debug("Sending config: ", configResponse);
|
||||
return {
|
||||
message: {
|
||||
type: "newt/wg/receive-config",
|
||||
data: {
|
||||
...configResponse
|
||||
ipAddress: site.address,
|
||||
peers,
|
||||
targets: targetsToSend,
|
||||
chainId: chainId
|
||||
}
|
||||
},
|
||||
options: {
|
||||
compress: canCompress(newt.version, "newt")
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
|
||||
36
server/routers/newt/handleNewtDisconnectingMessage.ts
Normal file
36
server/routers/newt/handleNewtDisconnectingMessage.ts
Normal file
@@ -0,0 +1,36 @@
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { db, Newt, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
/**
|
||||
* Handles disconnecting messages from sites to show disconnected in the ui
|
||||
*/
|
||||
export const handleNewtDisconnectingMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const newt = c as Newt;
|
||||
|
||||
if (!newt) {
|
||||
logger.warn("Newt not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!newt.siteId) {
|
||||
logger.warn("Newt has no client ID!");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Update the client's last ping timestamp
|
||||
await db
|
||||
.update(sites)
|
||||
.set({
|
||||
online: false
|
||||
})
|
||||
.where(eq(sites.siteId, newt.siteId));
|
||||
} catch (error) {
|
||||
logger.error("Error handling disconnecting message", { error });
|
||||
}
|
||||
};
|
||||
163
server/routers/newt/handleNewtPingMessage.ts
Normal file
163
server/routers/newt/handleNewtPingMessage.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
import { db, newts, sites } from "@server/db";
|
||||
import { hasActiveConnections, getClientConfigVersion } from "#dynamic/routers/ws";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { Newt } from "@server/db";
|
||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { sendNewtSyncMessage } from "./sync";
|
||||
import { recordPing } from "./pingAccumulator";
|
||||
|
||||
// Track if the offline checker interval is running
|
||||
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||
const OFFLINE_CHECK_INTERVAL = 30 * 1000; // Check every 30 seconds
|
||||
const OFFLINE_THRESHOLD_MS = 2 * 60 * 1000; // 2 minutes
|
||||
|
||||
/**
|
||||
* Starts the background interval that checks for newt sites that haven't
|
||||
* pinged recently and marks them as offline. For backward compatibility,
|
||||
* a site is only marked offline when there is no active WebSocket connection
|
||||
* either — so older newt versions that don't send pings but remain connected
|
||||
* continue to be treated as online.
|
||||
*/
|
||||
export const startNewtOfflineChecker = (): void => {
|
||||
if (offlineCheckerInterval) {
|
||||
return; // Already running
|
||||
}
|
||||
|
||||
offlineCheckerInterval = setInterval(async () => {
|
||||
try {
|
||||
const twoMinutesAgo = Math.floor(
|
||||
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
|
||||
);
|
||||
|
||||
// Find all online newt-type sites that haven't pinged recently
|
||||
// (or have never pinged at all). Join newts to obtain the newtId
|
||||
// needed for the WebSocket connection check.
|
||||
const staleSites = await db
|
||||
.select({
|
||||
siteId: sites.siteId,
|
||||
newtId: newts.newtId,
|
||||
lastPing: sites.lastPing
|
||||
})
|
||||
.from(sites)
|
||||
.innerJoin(newts, eq(newts.siteId, sites.siteId))
|
||||
.where(
|
||||
and(
|
||||
eq(sites.online, true),
|
||||
eq(sites.type, "newt"),
|
||||
or(
|
||||
lt(sites.lastPing, twoMinutesAgo),
|
||||
isNull(sites.lastPing)
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
for (const staleSite of staleSites) {
|
||||
// Backward-compatibility check: if the newt still has an
|
||||
// active WebSocket connection (older clients that don't send
|
||||
// pings), keep the site online.
|
||||
const isConnected = await hasActiveConnections(staleSite.newtId);
|
||||
if (isConnected) {
|
||||
logger.debug(
|
||||
`Newt ${staleSite.newtId} has not pinged recently but is still connected via WebSocket — keeping site ${staleSite.siteId} online`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`Marking site ${staleSite.siteId} offline: newt ${staleSite.newtId} has no recent ping and no active WebSocket connection`
|
||||
);
|
||||
|
||||
await db
|
||||
.update(sites)
|
||||
.set({ online: false })
|
||||
.where(eq(sites.siteId, staleSite.siteId));
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Error in newt offline checker interval", { error });
|
||||
}
|
||||
}, OFFLINE_CHECK_INTERVAL);
|
||||
|
||||
logger.debug("Started newt offline checker interval");
|
||||
};
|
||||
|
||||
/**
|
||||
* Stops the background interval that checks for offline newt sites.
|
||||
*/
|
||||
export const stopNewtOfflineChecker = (): void => {
|
||||
if (offlineCheckerInterval) {
|
||||
clearInterval(offlineCheckerInterval);
|
||||
offlineCheckerInterval = null;
|
||||
logger.info("Stopped newt offline checker interval");
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles ping messages from newt clients.
|
||||
*
|
||||
* On each ping:
|
||||
* - Marks the associated site as online.
|
||||
* - Records the current timestamp as the newt's last-ping time.
|
||||
* - Triggers a config sync if the newt is running an outdated config version.
|
||||
* - Responds with a pong message.
|
||||
*/
|
||||
export const handleNewtPingMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c } = context;
|
||||
const newt = c as Newt;
|
||||
|
||||
if (!newt) {
|
||||
logger.warn("Newt ping message: Newt not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!newt.siteId) {
|
||||
logger.warn("Newt ping message: has no site ID");
|
||||
return;
|
||||
}
|
||||
|
||||
// Record the ping in memory; it will be flushed to the database
|
||||
// periodically by the ping accumulator (every ~10s) in a single
|
||||
// batched UPDATE instead of one query per ping. This prevents
|
||||
// connection pool exhaustion under load, especially with
|
||||
// cross-region latency to the database.
|
||||
recordPing(newt.siteId);
|
||||
|
||||
// Check config version and sync if stale.
|
||||
const configVersion = await getClientConfigVersion(newt.newtId);
|
||||
|
||||
if (
|
||||
message.configVersion != null &&
|
||||
configVersion != null &&
|
||||
configVersion !== message.configVersion
|
||||
) {
|
||||
logger.warn(
|
||||
`Newt ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
|
||||
);
|
||||
|
||||
const [site] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, newt.siteId))
|
||||
.limit(1);
|
||||
|
||||
if (!site) {
|
||||
logger.warn(
|
||||
`Newt ping message: site with ID ${newt.siteId} not found`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
await sendNewtSyncMessage(newt, site);
|
||||
}
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "pong",
|
||||
data: {
|
||||
timestamp: new Date().toISOString()
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
};
|
||||
@@ -1,23 +1,19 @@
|
||||
import { db, ExitNode, exitNodeOrgs, newts, Transaction } from "@server/db";
|
||||
import { db, ExitNode, newts, Transaction } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { exitNodes, Newt, resources, sites, Target, targets } from "@server/db";
|
||||
import { targetHealthCheck } from "@server/db";
|
||||
import { eq, and, sql, inArray, ne } from "drizzle-orm";
|
||||
import { exitNodes, Newt, sites } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../gerbil/peers";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
import {
|
||||
findNextAvailableCidr,
|
||||
getNextAvailableClientSubnet
|
||||
} from "@server/lib/ip";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { findNextAvailableCidr } from "@server/lib/ip";
|
||||
import {
|
||||
selectBestExitNode,
|
||||
verifyExitNodeOrgAccess
|
||||
} from "#dynamic/lib/exitNodes";
|
||||
import { fetchContainers } from "./dockerSocket";
|
||||
import { lockManager } from "#dynamic/lib/lock";
|
||||
import { buildTargetConfigurationForNewtClient } from "./buildConfiguration";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
|
||||
export type ExitNodePingResult = {
|
||||
exitNodeId: number;
|
||||
@@ -29,8 +25,6 @@ export type ExitNodePingResult = {
|
||||
wasPreviouslyConnected: boolean;
|
||||
};
|
||||
|
||||
const numTimesLimitExceededForId: Record<string, number> = {};
|
||||
|
||||
export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
const newt = client as Newt;
|
||||
@@ -49,7 +43,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
|
||||
const siteId = newt.siteId;
|
||||
|
||||
const { publicKey, pingResults, newtVersion, backwardsCompatible } =
|
||||
const { publicKey, pingResults, newtVersion, backwardsCompatible, chainId } =
|
||||
message.data;
|
||||
if (!publicKey) {
|
||||
logger.warn("Public key not provided");
|
||||
@@ -95,42 +89,6 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
fetchContainers(newt.newtId);
|
||||
}
|
||||
|
||||
const rejectSiteUptime = await usageService.checkLimitSet(
|
||||
oldSite.orgId,
|
||||
false,
|
||||
FeatureId.SITE_UPTIME
|
||||
);
|
||||
const rejectEgressDataMb = await usageService.checkLimitSet(
|
||||
oldSite.orgId,
|
||||
false,
|
||||
FeatureId.EGRESS_DATA_MB
|
||||
);
|
||||
|
||||
// Do we need to check the users and domains daily limits here?
|
||||
// const rejectUsers = await usageService.checkLimitSet(oldSite.orgId, false, FeatureId.USERS);
|
||||
// const rejectDomains = await usageService.checkLimitSet(oldSite.orgId, false, FeatureId.DOMAINS);
|
||||
|
||||
// if (rejectEgressDataMb || rejectSiteUptime || rejectUsers || rejectDomains) {
|
||||
if (rejectEgressDataMb || rejectSiteUptime) {
|
||||
logger.info(
|
||||
`Usage limits exceeded for org ${oldSite.orgId}. Rejecting newt registration.`
|
||||
);
|
||||
|
||||
// PREVENT FURTHER REGISTRATION ATTEMPTS SO WE DON'T SPAM
|
||||
|
||||
// Increment the limit exceeded count for this site
|
||||
numTimesLimitExceededForId[newt.newtId] =
|
||||
(numTimesLimitExceededForId[newt.newtId] || 0) + 1;
|
||||
|
||||
if (numTimesLimitExceededForId[newt.newtId] > 15) {
|
||||
logger.debug(
|
||||
`Newt ${newt.newtId} has exceeded usage limits 15 times. Terminating...`
|
||||
);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
let siteSubnet = oldSite.subnet;
|
||||
let exitNodeIdToQuery = oldSite.exitNodeId;
|
||||
if (exitNodeId && (oldSite.exitNodeId !== exitNodeId || !oldSite.subnet)) {
|
||||
@@ -233,109 +191,8 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
.where(eq(newts.newtId, newt.newtId));
|
||||
}
|
||||
|
||||
// Get all enabled targets with their resource protocol information
|
||||
const allTargets = await db
|
||||
.select({
|
||||
resourceId: targets.resourceId,
|
||||
targetId: targets.targetId,
|
||||
ip: targets.ip,
|
||||
method: targets.method,
|
||||
port: targets.port,
|
||||
internalPort: targets.internalPort,
|
||||
enabled: targets.enabled,
|
||||
protocol: resources.protocol,
|
||||
hcEnabled: targetHealthCheck.hcEnabled,
|
||||
hcPath: targetHealthCheck.hcPath,
|
||||
hcScheme: targetHealthCheck.hcScheme,
|
||||
hcMode: targetHealthCheck.hcMode,
|
||||
hcHostname: targetHealthCheck.hcHostname,
|
||||
hcPort: targetHealthCheck.hcPort,
|
||||
hcInterval: targetHealthCheck.hcInterval,
|
||||
hcUnhealthyInterval: targetHealthCheck.hcUnhealthyInterval,
|
||||
hcTimeout: targetHealthCheck.hcTimeout,
|
||||
hcHeaders: targetHealthCheck.hcHeaders,
|
||||
hcMethod: targetHealthCheck.hcMethod,
|
||||
hcTlsServerName: targetHealthCheck.hcTlsServerName
|
||||
})
|
||||
.from(targets)
|
||||
.innerJoin(resources, eq(targets.resourceId, resources.resourceId))
|
||||
.leftJoin(
|
||||
targetHealthCheck,
|
||||
eq(targets.targetId, targetHealthCheck.targetId)
|
||||
)
|
||||
.where(and(eq(targets.siteId, siteId), eq(targets.enabled, true)));
|
||||
|
||||
const { tcpTargets, udpTargets } = allTargets.reduce(
|
||||
(acc, target) => {
|
||||
// Filter out invalid targets
|
||||
if (!target.internalPort || !target.ip || !target.port) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
// Format target into string
|
||||
const formattedTarget = `${target.internalPort}:${target.ip}:${target.port}`;
|
||||
|
||||
// Add to the appropriate protocol array
|
||||
if (target.protocol === "tcp") {
|
||||
acc.tcpTargets.push(formattedTarget);
|
||||
} else {
|
||||
acc.udpTargets.push(formattedTarget);
|
||||
}
|
||||
|
||||
return acc;
|
||||
},
|
||||
{ tcpTargets: [] as string[], udpTargets: [] as string[] }
|
||||
);
|
||||
|
||||
const healthCheckTargets = allTargets.map((target) => {
|
||||
// make sure the stuff is defined
|
||||
if (
|
||||
!target.hcPath ||
|
||||
!target.hcHostname ||
|
||||
!target.hcPort ||
|
||||
!target.hcInterval ||
|
||||
!target.hcMethod
|
||||
) {
|
||||
logger.debug(
|
||||
`Skipping target ${target.targetId} due to missing health check fields`
|
||||
);
|
||||
return null; // Skip targets with missing health check fields
|
||||
}
|
||||
|
||||
// parse headers
|
||||
const hcHeadersParse = target.hcHeaders
|
||||
? JSON.parse(target.hcHeaders)
|
||||
: null;
|
||||
const hcHeadersSend: { [key: string]: string } = {};
|
||||
if (hcHeadersParse) {
|
||||
hcHeadersParse.forEach(
|
||||
(header: { name: string; value: string }) => {
|
||||
hcHeadersSend[header.name] = header.value;
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
id: target.targetId,
|
||||
hcEnabled: target.hcEnabled,
|
||||
hcPath: target.hcPath,
|
||||
hcScheme: target.hcScheme,
|
||||
hcMode: target.hcMode,
|
||||
hcHostname: target.hcHostname,
|
||||
hcPort: target.hcPort,
|
||||
hcInterval: target.hcInterval, // in seconds
|
||||
hcUnhealthyInterval: target.hcUnhealthyInterval, // in seconds
|
||||
hcTimeout: target.hcTimeout, // in seconds
|
||||
hcHeaders: hcHeadersSend,
|
||||
hcMethod: target.hcMethod,
|
||||
hcTlsServerName: target.hcTlsServerName
|
||||
};
|
||||
});
|
||||
|
||||
// Filter out any null values from health check targets
|
||||
const validHealthCheckTargets = healthCheckTargets.filter(
|
||||
(target) => target !== null
|
||||
);
|
||||
const { tcpTargets, udpTargets, validHealthCheckTargets } =
|
||||
await buildTargetConfigurationForNewtClient(siteId);
|
||||
|
||||
logger.debug(
|
||||
`Sending health check targets to newt ${newt.newtId}: ${JSON.stringify(validHealthCheckTargets)}`
|
||||
@@ -346,6 +203,7 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
type: "newt/wg/connect",
|
||||
data: {
|
||||
endpoint: `${exitNode.endpoint}:${exitNode.listenPort}`,
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
publicKey: exitNode.publicKey,
|
||||
serverIP: exitNode.address.split("/")[0],
|
||||
tunnelIP: siteSubnet.split("/")[0],
|
||||
@@ -353,9 +211,13 @@ export const handleNewtRegisterMessage: MessageHandler = async (context) => {
|
||||
udp: udpTargets,
|
||||
tcp: tcpTargets
|
||||
},
|
||||
healthCheckTargets: validHealthCheckTargets
|
||||
healthCheckTargets: validHealthCheckTargets,
|
||||
chainId: chainId
|
||||
}
|
||||
},
|
||||
options: {
|
||||
compress: canCompress(newt.version, "newt")
|
||||
},
|
||||
broadcast: false, // Send to all clients
|
||||
excludeSender: false // Include sender in broadcast
|
||||
};
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { db } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, Newt } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { clients } from "@server/db";
|
||||
import { eq, sql } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
interface PeerBandwidth {
|
||||
@@ -10,13 +10,152 @@ interface PeerBandwidth {
|
||||
bytesOut: number;
|
||||
}
|
||||
|
||||
interface BandwidthAccumulator {
|
||||
bytesIn: number;
|
||||
bytesOut: number;
|
||||
}
|
||||
|
||||
// Retry configuration for deadlock handling
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 50;
|
||||
|
||||
// How often to flush accumulated bandwidth data to the database
|
||||
const FLUSH_INTERVAL_MS = 120_000; // 120 seconds
|
||||
|
||||
// In-memory accumulator: publicKey -> { bytesIn, bytesOut }
|
||||
let accumulator = new Map<string, BandwidthAccumulator>();
|
||||
|
||||
/**
|
||||
* Check if an error is a deadlock error
|
||||
*/
|
||||
function isDeadlockError(error: any): boolean {
|
||||
return (
|
||||
error?.code === "40P01" ||
|
||||
error?.cause?.code === "40P01" ||
|
||||
(error?.message && error.message.includes("deadlock"))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a function with retry logic for deadlock handling
|
||||
*/
|
||||
async function withDeadlockRetry<T>(
|
||||
operation: () => Promise<T>,
|
||||
context: string
|
||||
): Promise<T> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
try {
|
||||
return await operation();
|
||||
} catch (error: any) {
|
||||
if (isDeadlockError(error) && attempt < MAX_RETRIES) {
|
||||
attempt++;
|
||||
const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS;
|
||||
const jitter = Math.random() * baseDelay;
|
||||
const delay = baseDelay + jitter;
|
||||
logger.warn(
|
||||
`Deadlock detected in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms`
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
continue;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flush all accumulated bandwidth data to the database.
|
||||
*
|
||||
* Swaps out the accumulator before writing so that any bandwidth messages
|
||||
* received during the flush are captured in the new accumulator rather than
|
||||
* being lost or causing contention. Entries that fail to write are re-queued
|
||||
* back into the accumulator so they will be retried on the next flush.
|
||||
*
|
||||
* This function is exported so that the application's graceful-shutdown
|
||||
* cleanup handler can call it before the process exits.
|
||||
*/
|
||||
export async function flushBandwidthToDb(): Promise<void> {
|
||||
if (accumulator.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Atomically swap out the accumulator so new data keeps flowing in
|
||||
// while we write the snapshot to the database.
|
||||
const snapshot = accumulator;
|
||||
accumulator = new Map<string, BandwidthAccumulator>();
|
||||
|
||||
const currentTime = new Date().toISOString();
|
||||
|
||||
// Sort by publicKey for consistent lock ordering across concurrent
|
||||
// writers — this is the same deadlock-prevention strategy used in the
|
||||
// original per-message implementation.
|
||||
const sortedEntries = [...snapshot.entries()].sort(([a], [b]) =>
|
||||
a.localeCompare(b)
|
||||
);
|
||||
|
||||
logger.debug(
|
||||
`Flushing accumulated bandwidth data for ${sortedEntries.length} client(s) to the database`
|
||||
);
|
||||
|
||||
for (const [publicKey, { bytesIn, bytesOut }] of sortedEntries) {
|
||||
try {
|
||||
await withDeadlockRetry(async () => {
|
||||
// Use atomic SQL increment to avoid the SELECT-then-UPDATE
|
||||
// anti-pattern and the races it would introduce.
|
||||
await db
|
||||
.update(clients)
|
||||
.set({
|
||||
// Note: bytesIn from peer goes to megabytesOut (data
|
||||
// sent to client) and bytesOut from peer goes to
|
||||
// megabytesIn (data received from client).
|
||||
megabytesOut: sql`COALESCE(${clients.megabytesOut}, 0) + ${bytesIn}`,
|
||||
megabytesIn: sql`COALESCE(${clients.megabytesIn}, 0) + ${bytesOut}`,
|
||||
lastBandwidthUpdate: currentTime
|
||||
})
|
||||
.where(eq(clients.pubKey, publicKey));
|
||||
}, `flush bandwidth for client ${publicKey}`);
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to flush bandwidth for client ${publicKey}:`,
|
||||
error
|
||||
);
|
||||
|
||||
// Re-queue the failed entry so it is retried on the next flush
|
||||
// rather than silently dropped.
|
||||
const existing = accumulator.get(publicKey);
|
||||
if (existing) {
|
||||
existing.bytesIn += bytesIn;
|
||||
existing.bytesOut += bytesOut;
|
||||
} else {
|
||||
accumulator.set(publicKey, { bytesIn, bytesOut });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const flushTimer = setInterval(async () => {
|
||||
try {
|
||||
await flushBandwidthToDb();
|
||||
} catch (error) {
|
||||
logger.error("Unexpected error during periodic bandwidth flush:", error);
|
||||
}
|
||||
}, FLUSH_INTERVAL_MS);
|
||||
|
||||
// Calling unref() means this timer will not keep the Node.js event loop alive
|
||||
// on its own — the process can still exit normally when there is no other work
|
||||
// left. The graceful-shutdown path (see server/cleanup.ts) will call
|
||||
// flushBandwidthToDb() explicitly before process.exit(), so no data is lost.
|
||||
flushTimer.unref();
|
||||
|
||||
export const handleReceiveBandwidthMessage: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
const { message } = context;
|
||||
|
||||
if (!message.data.bandwidthData) {
|
||||
logger.warn("No bandwidth data provided");
|
||||
return;
|
||||
}
|
||||
|
||||
const bandwidthData: PeerBandwidth[] = message.data.bandwidthData;
|
||||
@@ -25,30 +164,21 @@ export const handleReceiveBandwidthMessage: MessageHandler = async (
|
||||
throw new Error("Invalid bandwidth data");
|
||||
}
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
for (const peer of bandwidthData) {
|
||||
const { publicKey, bytesIn, bytesOut } = peer;
|
||||
|
||||
// Find the client by public key
|
||||
const [client] = await trx
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.pubKey, publicKey))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Update the client's bandwidth usage
|
||||
await trx
|
||||
.update(clients)
|
||||
.set({
|
||||
megabytesOut: (client.megabytesIn || 0) + bytesIn,
|
||||
megabytesIn: (client.megabytesOut || 0) + bytesOut,
|
||||
lastBandwidthUpdate: new Date().toISOString()
|
||||
})
|
||||
.where(eq(clients.clientId, client.clientId));
|
||||
// Accumulate the incoming data in memory; the periodic timer (and the
|
||||
// shutdown hook) will take care of writing it to the database.
|
||||
for (const { publicKey, bytesIn, bytesOut } of bandwidthData) {
|
||||
// Skip peers that haven't transferred any data — writing zeros to the
|
||||
// database would be a no-op anyway.
|
||||
if (bytesIn <= 0 && bytesOut <= 0) {
|
||||
continue;
|
||||
}
|
||||
});
|
||||
|
||||
const existing = accumulator.get(publicKey);
|
||||
if (existing) {
|
||||
existing.bytesIn += bytesIn;
|
||||
existing.bytesOut += bytesOut;
|
||||
} else {
|
||||
accumulator.set(publicKey, { bytesIn, bytesOut });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -2,7 +2,7 @@ import { MessageHandler } from "@server/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { Newt } from "@server/db";
|
||||
import { applyNewtDockerBlueprint } from "@server/lib/blueprints/applyNewtDockerBlueprint";
|
||||
import cache from "@server/lib/cache";
|
||||
import cache from "#dynamic/lib/cache";
|
||||
|
||||
export const handleDockerStatusMessage: MessageHandler = async (context) => {
|
||||
const { message, client, sendToClient } = context;
|
||||
@@ -24,8 +24,8 @@ export const handleDockerStatusMessage: MessageHandler = async (context) => {
|
||||
|
||||
if (available) {
|
||||
logger.info(`Newt ${newt.newtId} has Docker socket access`);
|
||||
cache.set(`${newt.newtId}:socketPath`, socketPath, 0);
|
||||
cache.set(`${newt.newtId}:isAvailable`, available, 0);
|
||||
await cache.set(`${newt.newtId}:socketPath`, socketPath, 0);
|
||||
await cache.set(`${newt.newtId}:isAvailable`, available, 0);
|
||||
} else {
|
||||
logger.warn(`Newt ${newt.newtId} does not have Docker socket access`);
|
||||
}
|
||||
@@ -54,7 +54,7 @@ export const handleDockerContainersMessage: MessageHandler = async (
|
||||
);
|
||||
|
||||
if (containers && containers.length > 0) {
|
||||
cache.set(`${newt.newtId}:dockerContainers`, containers, 0);
|
||||
await cache.set(`${newt.newtId}:dockerContainers`, containers, 0);
|
||||
} else {
|
||||
logger.warn(`Newt ${newt.newtId} does not have Docker containers`);
|
||||
}
|
||||
|
||||
@@ -6,3 +6,7 @@ export * from "./handleGetConfigMessage";
|
||||
export * from "./handleSocketMessages";
|
||||
export * from "./handleNewtPingRequestMessage";
|
||||
export * from "./handleApplyBlueprintMessage";
|
||||
export * from "./handleNewtPingMessage";
|
||||
export * from "./handleNewtDisconnectingMessage";
|
||||
export * from "./handleConnectionLogMessage";
|
||||
export * from "./registerNewt";
|
||||
|
||||
@@ -39,7 +39,7 @@ export async function addPeer(
|
||||
await sendToClient(newtId, {
|
||||
type: "newt/wg/peer/add",
|
||||
data: peer
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -81,7 +81,7 @@ export async function deletePeer(
|
||||
data: {
|
||||
publicKey
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -128,7 +128,7 @@ export async function updatePeer(
|
||||
publicKey,
|
||||
...peer
|
||||
}
|
||||
}).catch((error) => {
|
||||
}, { incrementConfigVersion: true }).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
|
||||
382
server/routers/newt/pingAccumulator.ts
Normal file
382
server/routers/newt/pingAccumulator.ts
Normal file
@@ -0,0 +1,382 @@
|
||||
import { db } from "@server/db";
|
||||
import { sites, clients, olms } from "@server/db";
|
||||
import { eq, inArray } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
/**
|
||||
* Ping Accumulator
|
||||
*
|
||||
* Instead of writing to the database on every single newt/olm ping (which
|
||||
* causes pool exhaustion under load, especially with cross-region latency),
|
||||
* we accumulate pings in memory and flush them to the database periodically
|
||||
* in a single batch.
|
||||
*
|
||||
* This is the same pattern used for bandwidth flushing in
|
||||
* receiveBandwidth.ts and handleReceiveBandwidthMessage.ts.
|
||||
*
|
||||
* Supports two kinds of pings:
|
||||
* - **Site pings** (from newts): update `sites.online` and `sites.lastPing`
|
||||
* - **Client pings** (from OLMs): update `clients.online`, `clients.lastPing`,
|
||||
* `clients.archived`, and optionally reset `olms.archived`
|
||||
*/
|
||||
|
||||
const FLUSH_INTERVAL_MS = 10_000; // Flush every 10 seconds
|
||||
const MAX_RETRIES = 2;
|
||||
const BASE_DELAY_MS = 50;
|
||||
|
||||
// ── Site (newt) pings ──────────────────────────────────────────────────
|
||||
// Map of siteId -> latest ping timestamp (unix seconds)
|
||||
const pendingSitePings: Map<number, number> = new Map();
|
||||
|
||||
// ── Client (OLM) pings ────────────────────────────────────────────────
|
||||
// Map of clientId -> latest ping timestamp (unix seconds)
|
||||
const pendingClientPings: Map<number, number> = new Map();
|
||||
// Set of olmIds whose `archived` flag should be reset to false
|
||||
const pendingOlmArchiveResets: Set<string> = new Set();
|
||||
|
||||
let flushTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
// ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Record a ping for a newt site. This does NOT write to the database
|
||||
* immediately. Instead it stores the latest ping timestamp in memory,
|
||||
* to be flushed periodically by the background timer.
|
||||
*/
|
||||
export function recordSitePing(siteId: number): void {
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
pendingSitePings.set(siteId, now);
|
||||
}
|
||||
|
||||
/** @deprecated Use `recordSitePing` instead. Alias kept for existing call-sites. */
|
||||
export const recordPing = recordSitePing;
|
||||
|
||||
/**
|
||||
* Record a ping for an OLM client. Batches the `clients` table update
|
||||
* (`online`, `lastPing`, `archived`) and, when `olmArchived` is true,
|
||||
* also queues an `olms` table update to clear the archived flag.
|
||||
*/
|
||||
export function recordClientPing(
|
||||
clientId: number,
|
||||
olmId: string,
|
||||
olmArchived: boolean
|
||||
): void {
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
pendingClientPings.set(clientId, now);
|
||||
if (olmArchived) {
|
||||
pendingOlmArchiveResets.add(olmId);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Flush Logic ────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Flush all accumulated site pings to the database.
|
||||
*/
|
||||
async function flushSitePingsToDb(): Promise<void> {
|
||||
if (pendingSitePings.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Snapshot and clear so new pings arriving during the flush go into a
|
||||
// fresh map for the next cycle.
|
||||
const pingsToFlush = new Map(pendingSitePings);
|
||||
pendingSitePings.clear();
|
||||
|
||||
// Sort by siteId for consistent lock ordering (prevents deadlocks)
|
||||
const sortedEntries = Array.from(pingsToFlush.entries()).sort(
|
||||
([a], [b]) => a - b
|
||||
);
|
||||
|
||||
const BATCH_SIZE = 50;
|
||||
for (let i = 0; i < sortedEntries.length; i += BATCH_SIZE) {
|
||||
const batch = sortedEntries.slice(i, i + BATCH_SIZE);
|
||||
|
||||
try {
|
||||
await withRetry(async () => {
|
||||
// Group by timestamp for efficient bulk updates
|
||||
const byTimestamp = new Map<number, number[]>();
|
||||
for (const [siteId, timestamp] of batch) {
|
||||
const group = byTimestamp.get(timestamp) || [];
|
||||
group.push(siteId);
|
||||
byTimestamp.set(timestamp, group);
|
||||
}
|
||||
|
||||
if (byTimestamp.size === 1) {
|
||||
const [timestamp, siteIds] = Array.from(
|
||||
byTimestamp.entries()
|
||||
)[0];
|
||||
await db
|
||||
.update(sites)
|
||||
.set({
|
||||
online: true,
|
||||
lastPing: timestamp
|
||||
})
|
||||
.where(inArray(sites.siteId, siteIds));
|
||||
} else {
|
||||
await db.transaction(async (tx) => {
|
||||
for (const [timestamp, siteIds] of byTimestamp) {
|
||||
await tx
|
||||
.update(sites)
|
||||
.set({
|
||||
online: true,
|
||||
lastPing: timestamp
|
||||
})
|
||||
.where(inArray(sites.siteId, siteIds));
|
||||
}
|
||||
});
|
||||
}
|
||||
}, "flushSitePingsToDb");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to flush site ping batch (${batch.length} sites), re-queuing for next cycle`,
|
||||
{ error }
|
||||
);
|
||||
for (const [siteId, timestamp] of batch) {
|
||||
const existing = pendingSitePings.get(siteId);
|
||||
if (!existing || existing < timestamp) {
|
||||
pendingSitePings.set(siteId, timestamp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flush all accumulated client (OLM) pings to the database.
|
||||
*/
|
||||
async function flushClientPingsToDb(): Promise<void> {
|
||||
if (pendingClientPings.size === 0 && pendingOlmArchiveResets.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Snapshot and clear
|
||||
const pingsToFlush = new Map(pendingClientPings);
|
||||
pendingClientPings.clear();
|
||||
|
||||
const olmResetsToFlush = new Set(pendingOlmArchiveResets);
|
||||
pendingOlmArchiveResets.clear();
|
||||
|
||||
// ── Flush client pings ─────────────────────────────────────────────
|
||||
if (pingsToFlush.size > 0) {
|
||||
const sortedEntries = Array.from(pingsToFlush.entries()).sort(
|
||||
([a], [b]) => a - b
|
||||
);
|
||||
|
||||
const BATCH_SIZE = 50;
|
||||
for (let i = 0; i < sortedEntries.length; i += BATCH_SIZE) {
|
||||
const batch = sortedEntries.slice(i, i + BATCH_SIZE);
|
||||
|
||||
try {
|
||||
await withRetry(async () => {
|
||||
const byTimestamp = new Map<number, number[]>();
|
||||
for (const [clientId, timestamp] of batch) {
|
||||
const group = byTimestamp.get(timestamp) || [];
|
||||
group.push(clientId);
|
||||
byTimestamp.set(timestamp, group);
|
||||
}
|
||||
|
||||
if (byTimestamp.size === 1) {
|
||||
const [timestamp, clientIds] = Array.from(
|
||||
byTimestamp.entries()
|
||||
)[0];
|
||||
await db
|
||||
.update(clients)
|
||||
.set({
|
||||
lastPing: timestamp,
|
||||
online: true,
|
||||
archived: false
|
||||
})
|
||||
.where(inArray(clients.clientId, clientIds));
|
||||
} else {
|
||||
await db.transaction(async (tx) => {
|
||||
for (const [timestamp, clientIds] of byTimestamp) {
|
||||
await tx
|
||||
.update(clients)
|
||||
.set({
|
||||
lastPing: timestamp,
|
||||
online: true,
|
||||
archived: false
|
||||
})
|
||||
.where(
|
||||
inArray(clients.clientId, clientIds)
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, "flushClientPingsToDb");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to flush client ping batch (${batch.length} clients), re-queuing for next cycle`,
|
||||
{ error }
|
||||
);
|
||||
for (const [clientId, timestamp] of batch) {
|
||||
const existing = pendingClientPings.get(clientId);
|
||||
if (!existing || existing < timestamp) {
|
||||
pendingClientPings.set(clientId, timestamp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Flush OLM archive resets ───────────────────────────────────────
|
||||
if (olmResetsToFlush.size > 0) {
|
||||
const olmIds = Array.from(olmResetsToFlush).sort();
|
||||
|
||||
const BATCH_SIZE = 50;
|
||||
for (let i = 0; i < olmIds.length; i += BATCH_SIZE) {
|
||||
const batch = olmIds.slice(i, i + BATCH_SIZE);
|
||||
|
||||
try {
|
||||
await withRetry(async () => {
|
||||
await db
|
||||
.update(olms)
|
||||
.set({ archived: false })
|
||||
.where(inArray(olms.olmId, batch));
|
||||
}, "flushOlmArchiveResets");
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
`Failed to flush OLM archive reset batch (${batch.length} olms), re-queuing for next cycle`,
|
||||
{ error }
|
||||
);
|
||||
for (const olmId of batch) {
|
||||
pendingOlmArchiveResets.add(olmId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flush everything — called by the interval timer and during shutdown.
|
||||
*/
|
||||
export async function flushPingsToDb(): Promise<void> {
|
||||
await flushSitePingsToDb();
|
||||
await flushClientPingsToDb();
|
||||
}
|
||||
|
||||
// ── Retry / Error Helpers ──────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Simple retry wrapper with exponential backoff for transient errors
|
||||
* (connection timeouts, unexpected disconnects).
|
||||
*/
|
||||
async function withRetry<T>(
|
||||
operation: () => Promise<T>,
|
||||
context: string
|
||||
): Promise<T> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
try {
|
||||
return await operation();
|
||||
} catch (error: any) {
|
||||
if (isTransientError(error) && attempt < MAX_RETRIES) {
|
||||
attempt++;
|
||||
const baseDelay = Math.pow(2, attempt - 1) * BASE_DELAY_MS;
|
||||
const jitter = Math.random() * baseDelay;
|
||||
const delay = baseDelay + jitter;
|
||||
logger.warn(
|
||||
`Transient DB error in ${context}, retrying attempt ${attempt}/${MAX_RETRIES} after ${delay.toFixed(0)}ms`
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
continue;
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect transient connection errors that are safe to retry.
|
||||
*/
|
||||
function isTransientError(error: any): boolean {
|
||||
if (!error) return false;
|
||||
|
||||
const message = (error.message || "").toLowerCase();
|
||||
const causeMessage = (error.cause?.message || "").toLowerCase();
|
||||
const code = error.code || "";
|
||||
|
||||
// Connection timeout / terminated
|
||||
if (
|
||||
message.includes("connection timeout") ||
|
||||
message.includes("connection terminated") ||
|
||||
message.includes("timeout exceeded when trying to connect") ||
|
||||
causeMessage.includes("connection terminated unexpectedly") ||
|
||||
causeMessage.includes("connection timeout")
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// PostgreSQL deadlock
|
||||
if (code === "40P01" || message.includes("deadlock")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// ECONNRESET, ECONNREFUSED, EPIPE
|
||||
if (
|
||||
code === "ECONNRESET" ||
|
||||
code === "ECONNREFUSED" ||
|
||||
code === "EPIPE" ||
|
||||
code === "ETIMEDOUT"
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// ── Lifecycle ──────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Start the background flush timer. Call this once at server startup.
|
||||
*/
|
||||
export function startPingAccumulator(): void {
|
||||
if (flushTimer) {
|
||||
return; // Already running
|
||||
}
|
||||
|
||||
flushTimer = setInterval(async () => {
|
||||
try {
|
||||
await flushPingsToDb();
|
||||
} catch (error) {
|
||||
logger.error("Unhandled error in ping accumulator flush", {
|
||||
error
|
||||
});
|
||||
}
|
||||
}, FLUSH_INTERVAL_MS);
|
||||
|
||||
// Don't prevent the process from exiting
|
||||
flushTimer.unref();
|
||||
|
||||
logger.info(
|
||||
`Ping accumulator started (flush interval: ${FLUSH_INTERVAL_MS}ms)`
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the background flush timer and perform a final flush.
|
||||
* Call this during graceful shutdown.
|
||||
*/
|
||||
export async function stopPingAccumulator(): Promise<void> {
|
||||
if (flushTimer) {
|
||||
clearInterval(flushTimer);
|
||||
flushTimer = null;
|
||||
}
|
||||
|
||||
// Final flush to persist any remaining pings
|
||||
try {
|
||||
await flushPingsToDb();
|
||||
} catch (error) {
|
||||
logger.error("Error during final ping accumulator flush", { error });
|
||||
}
|
||||
|
||||
logger.info("Ping accumulator stopped");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of pending (unflushed) pings. Useful for monitoring.
|
||||
*/
|
||||
export function getPendingPingCount(): number {
|
||||
return pendingSitePings.size + pendingClientPings.size;
|
||||
}
|
||||
266
server/routers/newt/registerNewt.ts
Normal file
266
server/routers/newt/registerNewt.ts
Normal file
@@ -0,0 +1,266 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import { db } from "@server/db";
|
||||
import {
|
||||
siteProvisioningKeys,
|
||||
siteProvisioningKeyOrg,
|
||||
newts,
|
||||
orgs,
|
||||
roles,
|
||||
roleSites,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import response from "@server/lib/response";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { eq, and, sql } from "drizzle-orm";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { verifyPassword, hashPassword } from "@server/auth/password";
|
||||
import {
|
||||
generateId,
|
||||
generateIdFromEntropySize
|
||||
} from "@server/auth/sessions/app";
|
||||
import { getUniqueSiteName } from "@server/db/names";
|
||||
import moment from "moment";
|
||||
import { build } from "@server/build";
|
||||
import { usageService } from "@server/lib/billing/usageService";
|
||||
import { FeatureId } from "@server/lib/billing";
|
||||
import { INSPECT_MAX_BYTES } from "buffer";
|
||||
import { v } from "@faker-js/faker/dist/airline-Dz1uGqgJ";
|
||||
|
||||
const bodySchema = z.object({
|
||||
provisioningKey: z.string().nonempty()
|
||||
});
|
||||
|
||||
export type RegisterNewtBody = z.infer<typeof bodySchema>;
|
||||
|
||||
export type RegisterNewtResponse = {
|
||||
newtId: string;
|
||||
secret: string;
|
||||
};
|
||||
|
||||
export async function registerNewt(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { provisioningKey } = parsedBody.data;
|
||||
|
||||
// Keys are in the format "siteProvisioningKeyId.secret"
|
||||
const dotIndex = provisioningKey.indexOf(".");
|
||||
if (dotIndex === -1) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Invalid provisioning key format"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const provisioningKeyId = provisioningKey.substring(0, dotIndex);
|
||||
const provisioningKeySecret = provisioningKey.substring(dotIndex + 1);
|
||||
|
||||
// Look up the provisioning key by ID, joining to get the orgId
|
||||
const [keyRecord] = await db
|
||||
.select({
|
||||
siteProvisioningKeyId:
|
||||
siteProvisioningKeys.siteProvisioningKeyId,
|
||||
siteProvisioningKeyHash:
|
||||
siteProvisioningKeys.siteProvisioningKeyHash,
|
||||
orgId: siteProvisioningKeyOrg.orgId,
|
||||
maxBatchSize: siteProvisioningKeys.maxBatchSize,
|
||||
numUsed: siteProvisioningKeys.numUsed,
|
||||
validUntil: siteProvisioningKeys.validUntil
|
||||
})
|
||||
.from(siteProvisioningKeys)
|
||||
.innerJoin(
|
||||
siteProvisioningKeyOrg,
|
||||
eq(
|
||||
siteProvisioningKeys.siteProvisioningKeyId,
|
||||
siteProvisioningKeyOrg.siteProvisioningKeyId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
eq(
|
||||
siteProvisioningKeys.siteProvisioningKeyId,
|
||||
provisioningKeyId
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!keyRecord) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Invalid provisioning key"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Verify the secret portion against the stored hash
|
||||
const validSecret = await verifyPassword(
|
||||
provisioningKeySecret,
|
||||
keyRecord.siteProvisioningKeyHash
|
||||
);
|
||||
if (!validSecret) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Invalid provisioning key"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (keyRecord.maxBatchSize && keyRecord.numUsed >= keyRecord.maxBatchSize) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Provisioning key has reached its maximum usage"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (keyRecord.validUntil && new Date(keyRecord.validUntil) < new Date()) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.UNAUTHORIZED,
|
||||
"Provisioning key has expired"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { orgId } = keyRecord;
|
||||
|
||||
// Verify the org exists
|
||||
const [org] = await db.select().from(orgs).where(eq(orgs.orgId, orgId));
|
||||
if (!org) {
|
||||
return next(
|
||||
createHttpError(HttpCode.NOT_FOUND, "Organization not found")
|
||||
);
|
||||
}
|
||||
|
||||
// SaaS billing check
|
||||
if (build == "saas") {
|
||||
const usage = await usageService.getUsage(orgId, FeatureId.SITES);
|
||||
if (!usage) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"No usage data found for this organization"
|
||||
)
|
||||
);
|
||||
}
|
||||
const rejectSites = await usageService.checkLimitSet(
|
||||
orgId,
|
||||
FeatureId.SITES,
|
||||
{
|
||||
...usage,
|
||||
instantaneousValue: (usage.instantaneousValue || 0) + 1
|
||||
}
|
||||
);
|
||||
if (rejectSites) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.FORBIDDEN,
|
||||
"Site limit exceeded. Please upgrade your plan."
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const niceId = await getUniqueSiteName(orgId);
|
||||
const newtId = generateId(15);
|
||||
const newtSecret = generateIdFromEntropySize(25);
|
||||
const secretHash = await hashPassword(newtSecret);
|
||||
|
||||
let newSiteId: number | undefined;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
// Create the site (type "newt", name = niceId)
|
||||
const [newSite] = await trx
|
||||
.insert(sites)
|
||||
.values({
|
||||
orgId,
|
||||
name: niceId,
|
||||
niceId,
|
||||
type: "newt",
|
||||
dockerSocketEnabled: true
|
||||
})
|
||||
.returning();
|
||||
|
||||
newSiteId = newSite.siteId;
|
||||
|
||||
// Grant admin role access to the new site
|
||||
const [adminRole] = await trx
|
||||
.select()
|
||||
.from(roles)
|
||||
.where(and(eq(roles.isAdmin, true), eq(roles.orgId, orgId)))
|
||||
.limit(1);
|
||||
|
||||
if (!adminRole) {
|
||||
throw new Error(`Admin role not found for org ${orgId}`);
|
||||
}
|
||||
|
||||
await trx.insert(roleSites).values({
|
||||
roleId: adminRole.roleId,
|
||||
siteId: newSite.siteId
|
||||
});
|
||||
|
||||
// Create the newt for this site
|
||||
await trx.insert(newts).values({
|
||||
newtId,
|
||||
secretHash,
|
||||
siteId: newSite.siteId,
|
||||
dateCreated: moment().toISOString()
|
||||
});
|
||||
|
||||
// Consume the provisioning key — cascade removes siteProvisioningKeyOrg
|
||||
await trx
|
||||
.update(siteProvisioningKeys)
|
||||
.set({
|
||||
lastUsed: moment().toISOString(),
|
||||
numUsed: sql`${siteProvisioningKeys.numUsed} + 1`
|
||||
})
|
||||
.where(
|
||||
eq(
|
||||
siteProvisioningKeys.siteProvisioningKeyId,
|
||||
provisioningKeyId
|
||||
)
|
||||
);
|
||||
|
||||
await usageService.add(orgId, FeatureId.SITES, 1, trx);
|
||||
});
|
||||
|
||||
logger.info(
|
||||
`Provisioned new site (ID: ${newSiteId}) and newt (ID: ${newtId}) for org ${orgId} via provisioning key ${provisioningKeyId}`
|
||||
);
|
||||
|
||||
return response<RegisterNewtResponse>(res, {
|
||||
data: {
|
||||
newtId,
|
||||
secret: newtSecret
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Newt registered successfully",
|
||||
status: HttpCode.CREATED
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
|
||||
);
|
||||
}
|
||||
}
|
||||
48
server/routers/newt/sync.ts
Normal file
48
server/routers/newt/sync.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
import { ExitNode, exitNodes, Newt, Site, db } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import {
|
||||
buildClientConfigurationForNewtClient,
|
||||
buildTargetConfigurationForNewtClient
|
||||
} from "./buildConfiguration";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
|
||||
export async function sendNewtSyncMessage(newt: Newt, site: Site) {
|
||||
const { tcpTargets, udpTargets, validHealthCheckTargets } =
|
||||
await buildTargetConfigurationForNewtClient(site.siteId);
|
||||
|
||||
let exitNode: ExitNode | undefined;
|
||||
if (site.exitNodeId) {
|
||||
[exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
}
|
||||
const { peers, targets } = await buildClientConfigurationForNewtClient(
|
||||
site,
|
||||
exitNode
|
||||
);
|
||||
|
||||
await sendToClient(
|
||||
newt.newtId,
|
||||
{
|
||||
type: "newt/sync",
|
||||
data: {
|
||||
proxyTargets: {
|
||||
udp: udpTargets,
|
||||
tcp: tcpTargets
|
||||
},
|
||||
healthCheckTargets: validHealthCheckTargets,
|
||||
peers: peers,
|
||||
clientTargets: targets
|
||||
}
|
||||
},
|
||||
{
|
||||
compress: canCompress(newt.version, "newt")
|
||||
}
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending newt sync message:`, error);
|
||||
});
|
||||
}
|
||||
@@ -2,13 +2,14 @@ import { Target, TargetHealthCheck, db, targetHealthCheck } from "@server/db";
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import logger from "@server/logger";
|
||||
import { eq, inArray } from "drizzle-orm";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
|
||||
export async function addTargets(
|
||||
newtId: string,
|
||||
targets: Target[],
|
||||
healthCheckData: TargetHealthCheck[],
|
||||
protocol: string,
|
||||
port: number | null = null
|
||||
version?: string | null
|
||||
) {
|
||||
//create a list of udp and tcp targets
|
||||
const payloadTargets = targets.map((target) => {
|
||||
@@ -22,7 +23,7 @@ export async function addTargets(
|
||||
data: {
|
||||
targets: payloadTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
|
||||
|
||||
// Create a map for quick lookup
|
||||
const healthCheckMap = new Map<number, TargetHealthCheck>();
|
||||
@@ -103,14 +104,14 @@ export async function addTargets(
|
||||
data: {
|
||||
targets: validHealthCheckTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
|
||||
}
|
||||
|
||||
export async function removeTargets(
|
||||
newtId: string,
|
||||
targets: Target[],
|
||||
protocol: string,
|
||||
port: number | null = null
|
||||
version?: string | null
|
||||
) {
|
||||
//create a list of udp and tcp targets
|
||||
const payloadTargets = targets.map((target) => {
|
||||
@@ -124,7 +125,7 @@ export async function removeTargets(
|
||||
data: {
|
||||
targets: payloadTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true });
|
||||
|
||||
const healthCheckTargets = targets.map((target) => {
|
||||
return target.targetId;
|
||||
@@ -135,5 +136,5 @@ export async function removeTargets(
|
||||
data: {
|
||||
ids: healthCheckTargets
|
||||
}
|
||||
});
|
||||
}, { incrementConfigVersion: true, compress: canCompress(version, "newt") });
|
||||
}
|
||||
|
||||
60
server/routers/olm/archiveUserOlm.ts
Normal file
60
server/routers/olm/archiveUserOlm.ts
Normal file
@@ -0,0 +1,60 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import response from "@server/lib/response";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
userId: z.string(),
|
||||
olmId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function archiveUserOlm(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId } = parsedParams.data;
|
||||
|
||||
await db.transaction(async (trx) => {
|
||||
await trx
|
||||
.update(olms)
|
||||
.set({ archived: true })
|
||||
.where(eq(olms.olmId, olmId));
|
||||
});
|
||||
|
||||
return response(res, {
|
||||
data: null,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Device archived successfully",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to archive device"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
188
server/routers/olm/buildConfiguration.ts
Normal file
188
server/routers/olm/buildConfiguration.ts
Normal file
@@ -0,0 +1,188 @@
|
||||
import {
|
||||
Client,
|
||||
clientSiteResourcesAssociationsCache,
|
||||
clientSitesAssociationsCache,
|
||||
db,
|
||||
exitNodes,
|
||||
siteResources,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import {
|
||||
Alias,
|
||||
generateAliasConfig,
|
||||
generateRemoteSubnets
|
||||
} from "@server/lib/ip";
|
||||
import logger from "@server/logger";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../newt/peers";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export async function buildSiteConfigurationForOlmClient(
|
||||
client: Client,
|
||||
publicKey: string | null,
|
||||
relay: boolean,
|
||||
jitMode: boolean = false
|
||||
) {
|
||||
const siteConfigurations: {
|
||||
siteId: number;
|
||||
name?: string
|
||||
endpoint?: string
|
||||
publicKey?: string
|
||||
serverIP?: string | null
|
||||
serverPort?: number | null
|
||||
remoteSubnets?: string[];
|
||||
aliases: Alias[];
|
||||
}[] = [];
|
||||
|
||||
// Get all sites data
|
||||
const sitesData = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Process each site
|
||||
for (const {
|
||||
sites: site,
|
||||
clientSitesAssociationsCache: association
|
||||
} of sitesData) {
|
||||
const allSiteResources = await db // only get the site resources that this client has access to
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
siteResources.siteResourceId,
|
||||
clientSiteResourcesAssociationsCache.siteResourceId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(siteResources.siteId, site.siteId),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
if (jitMode) {
|
||||
// Add site configuration to the array
|
||||
siteConfigurations.push({
|
||||
siteId: site.siteId,
|
||||
// remoteSubnets: generateRemoteSubnets(
|
||||
// allSiteResources.map(({ siteResources }) => siteResources)
|
||||
// ),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} does not have exit node, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate endpoint and hole punch status
|
||||
if (!site.endpoint) {
|
||||
logger.warn(
|
||||
`In olm register: site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!site.publicKey || site.publicKey == "") { // the site is not ready to accept new peers
|
||||
logger.warn(
|
||||
`Site ${site.siteId} has no public key, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
|
||||
// logger.warn(
|
||||
// `Site ${site.siteId} last hole punch is too old, skipping`
|
||||
// );
|
||||
// continue;
|
||||
// }
|
||||
|
||||
// If public key changed, delete old peer from this site
|
||||
if (client.pubKey && client.pubKey != publicKey) {
|
||||
logger.info(
|
||||
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
|
||||
);
|
||||
await deletePeer(site.siteId, client.pubKey!);
|
||||
}
|
||||
|
||||
if (!site.subnet) {
|
||||
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [clientSite] = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
// Add the peer to the exit node for this site
|
||||
if (clientSite.endpoint && publicKey) {
|
||||
logger.info(
|
||||
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
|
||||
);
|
||||
await addPeer(site.siteId, {
|
||||
publicKey: publicKey,
|
||||
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
endpoint: relay ? "" : clientSite.endpoint
|
||||
});
|
||||
} else {
|
||||
logger.warn(
|
||||
`Client ${client.clientId} has no endpoint, skipping peer addition`
|
||||
);
|
||||
}
|
||||
|
||||
let relayEndpoint: string | undefined = undefined;
|
||||
if (relay) {
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
continue;
|
||||
}
|
||||
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
||||
}
|
||||
|
||||
// Add site configuration to the array
|
||||
siteConfigurations.push({
|
||||
siteId: site.siteId,
|
||||
name: site.name,
|
||||
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
|
||||
endpoint: site.endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: generateRemoteSubnets(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
return siteConfigurations;
|
||||
}
|
||||
@@ -46,7 +46,7 @@ export async function createNewt(
|
||||
|
||||
const { newtId, secret } = parsedBody.data;
|
||||
|
||||
if (req.user && !req.userOrgRoleId) {
|
||||
if (req.user && (!req.userOrgRoleIds || req.userOrgRoleIds.length === 0)) {
|
||||
return next(
|
||||
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
|
||||
);
|
||||
|
||||
@@ -11,6 +11,7 @@ import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { rebuildClientAssociationsFromClient } from "@server/lib/rebuildClientAssociations";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
import { OlmErrorCodes } from "./error";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
@@ -76,6 +77,7 @@ export async function deleteUserOlm(
|
||||
if (olm) {
|
||||
await sendTerminateClient(
|
||||
deletedClient.clientId,
|
||||
OlmErrorCodes.TERMINATED_DELETED,
|
||||
olm.olmId
|
||||
); // the olmId needs to be provided because it cant look it up after deletion
|
||||
}
|
||||
|
||||
104
server/routers/olm/error.ts
Normal file
104
server/routers/olm/error.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
// Error codes for registration failures
|
||||
export const OlmErrorCodes = {
|
||||
OLM_NOT_FOUND: {
|
||||
code: "OLM_NOT_FOUND",
|
||||
message: "The specified device could not be found."
|
||||
},
|
||||
CLIENT_ID_NOT_FOUND: {
|
||||
code: "CLIENT_ID_NOT_FOUND",
|
||||
message: "No client ID was provided in the request."
|
||||
},
|
||||
CLIENT_NOT_FOUND: {
|
||||
code: "CLIENT_NOT_FOUND",
|
||||
message: "The specified client does not exist."
|
||||
},
|
||||
CLIENT_BLOCKED: {
|
||||
code: "CLIENT_BLOCKED",
|
||||
message:
|
||||
"This client has been blocked in this organization and cannot connect. Please contact your administrator."
|
||||
},
|
||||
CLIENT_PENDING: {
|
||||
code: "CLIENT_PENDING",
|
||||
message:
|
||||
"This client is pending approval and cannot connect yet. Please contact your administrator."
|
||||
},
|
||||
ORG_NOT_FOUND: {
|
||||
code: "ORG_NOT_FOUND",
|
||||
message:
|
||||
"The organization could not be found. Please select a valid organization."
|
||||
},
|
||||
USER_ID_NOT_FOUND: {
|
||||
code: "USER_ID_NOT_FOUND",
|
||||
message: "No user ID was provided in the request."
|
||||
},
|
||||
INVALID_USER_SESSION: {
|
||||
code: "INVALID_USER_SESSION",
|
||||
message:
|
||||
"Your user session is invalid or has expired. Please log in again."
|
||||
},
|
||||
USER_ID_MISMATCH: {
|
||||
code: "USER_ID_MISMATCH",
|
||||
message: "The provided user ID does not match the session."
|
||||
},
|
||||
ORG_ACCESS_POLICY_DENIED: {
|
||||
code: "ORG_ACCESS_POLICY_DENIED",
|
||||
message:
|
||||
"Access to this organization has been denied by policy. Please contact your administrator."
|
||||
},
|
||||
ORG_ACCESS_POLICY_PASSWORD_EXPIRED: {
|
||||
code: "ORG_ACCESS_POLICY_PASSWORD_EXPIRED",
|
||||
message:
|
||||
"Access to this organization has been denied because your password has expired. Please visit this organization's dashboard to update your password."
|
||||
},
|
||||
ORG_ACCESS_POLICY_SESSION_EXPIRED: {
|
||||
code: "ORG_ACCESS_POLICY_SESSION_EXPIRED",
|
||||
message:
|
||||
"Access to this organization has been denied because your session has expired. Please log in again to refresh the session."
|
||||
},
|
||||
ORG_ACCESS_POLICY_2FA_REQUIRED: {
|
||||
code: "ORG_ACCESS_POLICY_2FA_REQUIRED",
|
||||
message:
|
||||
"Access to this organization requires two-factor authentication. Please visit this organization's dashboard to enable two-factor authentication."
|
||||
},
|
||||
TERMINATED_REKEYED: {
|
||||
code: "TERMINATED_REKEYED",
|
||||
message:
|
||||
"This session was terminated because encryption keys were regenerated."
|
||||
},
|
||||
TERMINATED_ORG_DELETED: {
|
||||
code: "TERMINATED_ORG_DELETED",
|
||||
message:
|
||||
"This session was terminated because the organization was deleted."
|
||||
},
|
||||
TERMINATED_INACTIVITY: {
|
||||
code: "TERMINATED_INACTIVITY",
|
||||
message: "This session was terminated due to inactivity."
|
||||
},
|
||||
TERMINATED_DELETED: {
|
||||
code: "TERMINATED_DELETED",
|
||||
message: "This session was terminated because it was deleted."
|
||||
},
|
||||
TERMINATED_ARCHIVED: {
|
||||
code: "TERMINATED_ARCHIVED",
|
||||
message: "This session was terminated because it was archived."
|
||||
},
|
||||
TERMINATED_BLOCKED: {
|
||||
code: "TERMINATED_BLOCKED",
|
||||
message: "This session was terminated because access was blocked."
|
||||
}
|
||||
} as const;
|
||||
|
||||
// Helper function to send registration error
|
||||
export async function sendOlmError(
|
||||
error: (typeof OlmErrorCodes)[keyof typeof OlmErrorCodes],
|
||||
olmId: string
|
||||
) {
|
||||
sendToClient(olmId, {
|
||||
type: "olm/error",
|
||||
data: {
|
||||
code: error.code,
|
||||
message: error.message
|
||||
}
|
||||
});
|
||||
}
|
||||
224
server/routers/olm/fingerprintingUtils.ts
Normal file
224
server/routers/olm/fingerprintingUtils.ts
Normal file
@@ -0,0 +1,224 @@
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { currentFingerprint, db, fingerprintSnapshots, Olm } from "@server/db";
|
||||
import { calculateCutoffTimestamp } from "@server/lib/cleanupLogs";
|
||||
import { desc, eq, lt } from "drizzle-orm";
|
||||
|
||||
function fingerprintSnapshotHash(fingerprint: any, postures: any): string {
|
||||
const canonical = {
|
||||
username: fingerprint.username ?? null,
|
||||
hostname: fingerprint.hostname ?? null,
|
||||
platform: fingerprint.platform ?? null,
|
||||
osVersion: fingerprint.osVersion ?? null,
|
||||
kernelVersion: fingerprint.kernelVersion ?? null,
|
||||
arch: fingerprint.arch ?? null,
|
||||
deviceModel: fingerprint.deviceModel ?? null,
|
||||
serialNumber: fingerprint.serialNumber ?? null,
|
||||
platformFingerprint: fingerprint.platformFingerprint ?? null,
|
||||
|
||||
biometricsEnabled: postures.biometricsEnabled ?? false,
|
||||
diskEncrypted: postures.diskEncrypted ?? false,
|
||||
firewallEnabled: postures.firewallEnabled ?? false,
|
||||
autoUpdatesEnabled: postures.autoUpdatesEnabled ?? false,
|
||||
tpmAvailable: postures.tpmAvailable ?? false,
|
||||
|
||||
windowsAntivirusEnabled: postures.windowsAntivirusEnabled ?? false,
|
||||
|
||||
macosSipEnabled: postures.macosSipEnabled ?? false,
|
||||
macosGatekeeperEnabled: postures.macosGatekeeperEnabled ?? false,
|
||||
macosFirewallStealthMode: postures.macosFirewallStealthMode ?? false,
|
||||
|
||||
linuxAppArmorEnabled: postures.linuxAppArmorEnabled ?? false,
|
||||
linuxSELinuxEnabled: postures.linuxSELinuxEnabled ?? false
|
||||
};
|
||||
|
||||
return encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(JSON.stringify(canonical)))
|
||||
);
|
||||
}
|
||||
|
||||
export async function handleFingerprintInsertion(
|
||||
olm: Olm,
|
||||
fingerprint: any,
|
||||
postures: any
|
||||
) {
|
||||
if (
|
||||
!olm?.olmId ||
|
||||
!fingerprint ||
|
||||
!postures ||
|
||||
Object.keys(fingerprint).length === 0 ||
|
||||
Object.keys(postures).length === 0
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
const hash = fingerprintSnapshotHash(fingerprint, postures);
|
||||
|
||||
const [current] = await db
|
||||
.select()
|
||||
.from(currentFingerprint)
|
||||
.where(eq(currentFingerprint.olmId, olm.olmId))
|
||||
.limit(1);
|
||||
|
||||
if (!current) {
|
||||
const [inserted] = await db
|
||||
.insert(currentFingerprint)
|
||||
.values({
|
||||
olmId: olm.olmId,
|
||||
firstSeen: now,
|
||||
lastSeen: now,
|
||||
lastCollectedAt: now,
|
||||
|
||||
// fingerprint
|
||||
username: fingerprint.username,
|
||||
hostname: fingerprint.hostname,
|
||||
platform: fingerprint.platform,
|
||||
osVersion: fingerprint.osVersion,
|
||||
kernelVersion: fingerprint.kernelVersion,
|
||||
arch: fingerprint.arch,
|
||||
deviceModel: fingerprint.deviceModel,
|
||||
serialNumber: fingerprint.serialNumber,
|
||||
platformFingerprint: fingerprint.platformFingerprint,
|
||||
|
||||
biometricsEnabled: postures.biometricsEnabled,
|
||||
diskEncrypted: postures.diskEncrypted,
|
||||
firewallEnabled: postures.firewallEnabled,
|
||||
autoUpdatesEnabled: postures.autoUpdatesEnabled,
|
||||
tpmAvailable: postures.tpmAvailable,
|
||||
|
||||
windowsAntivirusEnabled: postures.windowsAntivirusEnabled,
|
||||
|
||||
macosSipEnabled: postures.macosSipEnabled,
|
||||
macosGatekeeperEnabled: postures.macosGatekeeperEnabled,
|
||||
macosFirewallStealthMode: postures.macosFirewallStealthMode,
|
||||
|
||||
linuxAppArmorEnabled: postures.linuxAppArmorEnabled,
|
||||
linuxSELinuxEnabled: postures.linuxSELinuxEnabled
|
||||
})
|
||||
.returning();
|
||||
|
||||
await db.insert(fingerprintSnapshots).values({
|
||||
fingerprintId: inserted.fingerprintId,
|
||||
|
||||
username: fingerprint.username,
|
||||
hostname: fingerprint.hostname,
|
||||
platform: fingerprint.platform,
|
||||
osVersion: fingerprint.osVersion,
|
||||
kernelVersion: fingerprint.kernelVersion,
|
||||
arch: fingerprint.arch,
|
||||
deviceModel: fingerprint.deviceModel,
|
||||
serialNumber: fingerprint.serialNumber,
|
||||
platformFingerprint: fingerprint.platformFingerprint,
|
||||
|
||||
biometricsEnabled: postures.biometricsEnabled,
|
||||
diskEncrypted: postures.diskEncrypted,
|
||||
firewallEnabled: postures.firewallEnabled,
|
||||
autoUpdatesEnabled: postures.autoUpdatesEnabled,
|
||||
tpmAvailable: postures.tpmAvailable,
|
||||
|
||||
windowsAntivirusEnabled: postures.windowsAntivirusEnabled,
|
||||
|
||||
macosSipEnabled: postures.macosSipEnabled,
|
||||
macosGatekeeperEnabled: postures.macosGatekeeperEnabled,
|
||||
macosFirewallStealthMode: postures.macosFirewallStealthMode,
|
||||
|
||||
linuxAppArmorEnabled: postures.linuxAppArmorEnabled,
|
||||
linuxSELinuxEnabled: postures.linuxSELinuxEnabled,
|
||||
|
||||
hash,
|
||||
collectedAt: now
|
||||
});
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
const [latestSnapshot] = await db
|
||||
.select({ hash: fingerprintSnapshots.hash })
|
||||
.from(fingerprintSnapshots)
|
||||
.where(eq(fingerprintSnapshots.fingerprintId, current.fingerprintId))
|
||||
.orderBy(desc(fingerprintSnapshots.collectedAt))
|
||||
.limit(1);
|
||||
|
||||
const changed = !latestSnapshot || latestSnapshot.hash !== hash;
|
||||
|
||||
if (changed) {
|
||||
await db.insert(fingerprintSnapshots).values({
|
||||
fingerprintId: current.fingerprintId,
|
||||
|
||||
username: fingerprint.username,
|
||||
hostname: fingerprint.hostname,
|
||||
platform: fingerprint.platform,
|
||||
osVersion: fingerprint.osVersion,
|
||||
kernelVersion: fingerprint.kernelVersion,
|
||||
arch: fingerprint.arch,
|
||||
deviceModel: fingerprint.deviceModel,
|
||||
serialNumber: fingerprint.serialNumber,
|
||||
platformFingerprint: fingerprint.platformFingerprint,
|
||||
|
||||
biometricsEnabled: postures.biometricsEnabled,
|
||||
diskEncrypted: postures.diskEncrypted,
|
||||
firewallEnabled: postures.firewallEnabled,
|
||||
autoUpdatesEnabled: postures.autoUpdatesEnabled,
|
||||
tpmAvailable: postures.tpmAvailable,
|
||||
|
||||
windowsAntivirusEnabled: postures.windowsAntivirusEnabled,
|
||||
|
||||
macosSipEnabled: postures.macosSipEnabled,
|
||||
macosGatekeeperEnabled: postures.macosGatekeeperEnabled,
|
||||
macosFirewallStealthMode: postures.macosFirewallStealthMode,
|
||||
|
||||
linuxAppArmorEnabled: postures.linuxAppArmorEnabled,
|
||||
linuxSELinuxEnabled: postures.linuxSELinuxEnabled,
|
||||
|
||||
hash,
|
||||
collectedAt: now
|
||||
});
|
||||
|
||||
await db
|
||||
.update(currentFingerprint)
|
||||
.set({
|
||||
lastSeen: now,
|
||||
lastCollectedAt: now,
|
||||
|
||||
username: fingerprint.username,
|
||||
hostname: fingerprint.hostname,
|
||||
platform: fingerprint.platform,
|
||||
osVersion: fingerprint.osVersion,
|
||||
kernelVersion: fingerprint.kernelVersion,
|
||||
arch: fingerprint.arch,
|
||||
deviceModel: fingerprint.deviceModel,
|
||||
serialNumber: fingerprint.serialNumber,
|
||||
platformFingerprint: fingerprint.platformFingerprint,
|
||||
|
||||
biometricsEnabled: postures.biometricsEnabled,
|
||||
diskEncrypted: postures.diskEncrypted,
|
||||
firewallEnabled: postures.firewallEnabled,
|
||||
autoUpdatesEnabled: postures.autoUpdatesEnabled,
|
||||
tpmAvailable: postures.tpmAvailable,
|
||||
|
||||
windowsAntivirusEnabled: postures.windowsAntivirusEnabled,
|
||||
|
||||
macosSipEnabled: postures.macosSipEnabled,
|
||||
macosGatekeeperEnabled: postures.macosGatekeeperEnabled,
|
||||
macosFirewallStealthMode: postures.macosFirewallStealthMode,
|
||||
|
||||
linuxAppArmorEnabled: postures.linuxAppArmorEnabled,
|
||||
linuxSELinuxEnabled: postures.linuxSELinuxEnabled
|
||||
})
|
||||
.where(eq(currentFingerprint.fingerprintId, current.fingerprintId));
|
||||
} else {
|
||||
await db
|
||||
.update(currentFingerprint)
|
||||
.set({ lastSeen: now })
|
||||
.where(eq(currentFingerprint.fingerprintId, current.fingerprintId));
|
||||
}
|
||||
}
|
||||
|
||||
export async function cleanUpOldFingerprintSnapshots(retentionDays: number) {
|
||||
const cutoff = calculateCutoffTimestamp(retentionDays);
|
||||
|
||||
await db
|
||||
.delete(fingerprintSnapshots)
|
||||
.where(lt(fingerprintSnapshots.collectedAt, cutoff));
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
import { generateSessionToken } from "@server/auth/sessions/app";
|
||||
import {
|
||||
generateSessionToken,
|
||||
validateSessionToken
|
||||
} from "@server/auth/sessions/app";
|
||||
import {
|
||||
clients,
|
||||
db,
|
||||
ExitNode,
|
||||
exitNodes,
|
||||
sites,
|
||||
clientSitesAssociationsCache
|
||||
clientSitesAssociationsCache,
|
||||
} from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -17,8 +20,10 @@ import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import {
|
||||
createOlmSession,
|
||||
validateOlmSessionToken
|
||||
validateOlmSessionToken,
|
||||
EXPIRES
|
||||
} from "@server/auth/sessions/olm";
|
||||
import { getOrCreateCachedToken } from "#dynamic/lib/tokenCache";
|
||||
import { verifyPassword } from "@server/auth/password";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
@@ -26,8 +31,9 @@ import { APP_VERSION } from "@server/lib/consts";
|
||||
|
||||
export const olmGetTokenBodySchema = z.object({
|
||||
olmId: z.string(),
|
||||
secret: z.string(),
|
||||
token: z.string().optional(),
|
||||
secret: z.string().optional(),
|
||||
userToken: z.string().optional(),
|
||||
token: z.string().optional(), // this is the olm token
|
||||
orgId: z.string().optional()
|
||||
});
|
||||
|
||||
@@ -49,7 +55,7 @@ export async function getOlmToken(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, secret, token, orgId } = parsedBody.data;
|
||||
const { olmId, secret, token, orgId, userToken } = parsedBody.data;
|
||||
|
||||
try {
|
||||
if (token) {
|
||||
@@ -84,26 +90,63 @@ export async function getOlmToken(
|
||||
);
|
||||
}
|
||||
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingOlm.secretHash
|
||||
);
|
||||
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
|
||||
if (userToken) {
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Invalid user token")
|
||||
);
|
||||
}
|
||||
if (user.userId !== existingOlm.userId) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"User token does not match olm"
|
||||
)
|
||||
);
|
||||
}
|
||||
} else if (secret) {
|
||||
// this is for backward compatibility, we want to move towards userToken but some old clients may still be using secret so we will support both for now
|
||||
const validSecret = await verifyPassword(
|
||||
secret,
|
||||
existingOlm.secretHash
|
||||
);
|
||||
|
||||
if (!validSecret) {
|
||||
if (config.getRawConfig().app.log_failed_attempts) {
|
||||
logger.info(
|
||||
`Olm id or secret is incorrect. Olm: ID ${olmId}. IP: ${req.ip}.`
|
||||
);
|
||||
}
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return next(
|
||||
createHttpError(HttpCode.BAD_REQUEST, "Secret is incorrect")
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"Either secret or userToken is required"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
logger.debug("Creating new olm session token");
|
||||
|
||||
const resToken = generateSessionToken();
|
||||
await createOlmSession(resToken, existingOlm.olmId);
|
||||
// Return a cached token if one exists to prevent thundering herd on
|
||||
// simultaneous restarts; falls back to creating a fresh session when
|
||||
// Redis is unavailable or the cache has expired.
|
||||
const resToken = await getOrCreateCachedToken(
|
||||
`olm:token_cache:${existingOlm.olmId}`,
|
||||
config.getRawConfig().server.secret!,
|
||||
Math.floor(EXPIRES / 1000),
|
||||
async () => {
|
||||
const token = generateSessionToken();
|
||||
await createOlmSession(token, existingOlm.olmId);
|
||||
return token;
|
||||
}
|
||||
);
|
||||
|
||||
let clientIdToUse;
|
||||
if (orgId) {
|
||||
@@ -194,10 +237,23 @@ export async function getOlmToken(
|
||||
.where(inArray(exitNodes.exitNodeId, exitNodeIds));
|
||||
}
|
||||
|
||||
// Map exitNodeId to siteIds
|
||||
const exitNodeIdToSiteIds: Record<number, number[]> = {};
|
||||
for (const { sites: site } of clientSites) {
|
||||
if (site.exitNodeId !== null) {
|
||||
if (!exitNodeIdToSiteIds[site.exitNodeId]) {
|
||||
exitNodeIdToSiteIds[site.exitNodeId] = [];
|
||||
}
|
||||
exitNodeIdToSiteIds[site.exitNodeId].push(site.siteId);
|
||||
}
|
||||
}
|
||||
|
||||
const exitNodesHpData = allExitNodes.map((exitNode: ExitNode) => {
|
||||
return {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
endpoint: exitNode.endpoint,
|
||||
siteIds: exitNodeIdToSiteIds[exitNode.exitNodeId] ?? []
|
||||
};
|
||||
});
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import { olms, clients, currentFingerprint } from "@server/db";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -8,7 +8,8 @@ import response from "@server/lib/response";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { getUserDeviceName } from "@server/db/names";
|
||||
// import { OpenAPITags, registry } from "@server/openApi";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
@@ -17,6 +18,10 @@ const paramsSchema = z
|
||||
})
|
||||
.strict();
|
||||
|
||||
const querySchema = z.object({
|
||||
orgId: z.string().optional()
|
||||
});
|
||||
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/user/{userId}/olm/{olmId}",
|
||||
@@ -44,15 +49,63 @@ export async function getUserOlm(
|
||||
);
|
||||
}
|
||||
|
||||
const { olmId, userId } = parsedParams.data;
|
||||
const parsedQuery = querySchema.safeParse(req.query);
|
||||
if (!parsedQuery.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedQuery.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [olm] = await db
|
||||
const { olmId, userId } = parsedParams.data;
|
||||
const { orgId } = parsedQuery.data;
|
||||
|
||||
const [result] = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(and(eq(olms.userId, userId), eq(olms.olmId, olmId)));
|
||||
.where(and(eq(olms.userId, userId), eq(olms.olmId, olmId)))
|
||||
.leftJoin(
|
||||
currentFingerprint,
|
||||
eq(olms.olmId, currentFingerprint.olmId)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!result || !result.olms) {
|
||||
return next(createHttpError(HttpCode.NOT_FOUND, "Olm not found"));
|
||||
}
|
||||
|
||||
const olm = result.olms;
|
||||
|
||||
// If orgId is provided and olm has a clientId, fetch the client to check blocked status
|
||||
let blocked: boolean | undefined;
|
||||
if (orgId && olm.clientId) {
|
||||
const [client] = await db
|
||||
.select({ blocked: clients.blocked })
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
eq(clients.clientId, olm.clientId),
|
||||
eq(clients.orgId, orgId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
blocked = client?.blocked ?? false;
|
||||
}
|
||||
|
||||
// Replace name with device name
|
||||
const model = result.currentFingerprint?.deviceModel || null;
|
||||
const newName = getUserDeviceName(model, olm.name);
|
||||
|
||||
const responseData =
|
||||
blocked !== undefined
|
||||
? { ...olm, name: newName, blocked }
|
||||
: { ...olm, name: newName };
|
||||
|
||||
return response(res, {
|
||||
data: olm,
|
||||
data: responseData,
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Successfully retrieved olm",
|
||||
|
||||
34
server/routers/olm/handleOlmDisconnectingMessage.ts
Normal file
34
server/routers/olm/handleOlmDisconnectingMessage.ts
Normal file
@@ -0,0 +1,34 @@
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, db, Olm } from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
|
||||
/**
|
||||
* Handles disconnecting messages from clients to show disconnected in the ui
|
||||
*/
|
||||
export const handleOlmDisconnectingMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no client ID!");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Update the client's last ping timestamp
|
||||
await db
|
||||
.update(clients)
|
||||
.set({
|
||||
online: false
|
||||
})
|
||||
.where(eq(clients.clientId, olm.clientId));
|
||||
} catch (error) {
|
||||
logger.error("Error handling disconnecting message", { error });
|
||||
}
|
||||
};
|
||||
@@ -1,14 +1,18 @@
|
||||
import { disconnectClient, getClientConfigVersion } from "#dynamic/routers/ws";
|
||||
import { db } from "@server/db";
|
||||
import { disconnectClient } from "#dynamic/routers/ws";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, Olm } from "@server/db";
|
||||
import { clients, olms, Olm } from "@server/db";
|
||||
import { eq, lt, isNull, and, or } from "drizzle-orm";
|
||||
import { recordClientPing } from "@server/routers/newt/pingAccumulator";
|
||||
import logger from "@server/logger";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { sendTerminateClient } from "../client/terminate";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { sendOlmSyncMessage } from "./sync";
|
||||
import { OlmErrorCodes } from "./error";
|
||||
import { handleFingerprintInsertion } from "./fingerprintingUtils";
|
||||
|
||||
// Track if the offline checker interval is running
|
||||
let offlineCheckerInterval: NodeJS.Timeout | null = null;
|
||||
@@ -63,6 +67,7 @@ export const startOlmOfflineChecker = (): void => {
|
||||
try {
|
||||
await sendTerminateClient(
|
||||
offlineClient.clientId,
|
||||
OlmErrorCodes.TERMINATED_INACTIVITY,
|
||||
offlineClient.olmId
|
||||
); // terminate first
|
||||
// wait a moment to ensure the message is sent
|
||||
@@ -101,79 +106,116 @@ export const handleOlmPingMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
const { userToken } = message.data;
|
||||
const { userToken, fingerprint, postures } = message.data;
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (olm.userId) {
|
||||
// we need to check a user token to make sure its still valid
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
logger.warn("Invalid user session for olm ping");
|
||||
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
|
||||
}
|
||||
if (user.userId !== olm.userId) {
|
||||
logger.warn("User ID mismatch for olm ping");
|
||||
return;
|
||||
}
|
||||
|
||||
// get the client
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(
|
||||
and(
|
||||
eq(clients.olmId, olm.olmId),
|
||||
eq(clients.userId, olm.userId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client not found for olm ping");
|
||||
return;
|
||||
}
|
||||
|
||||
const sessionId = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(userToken))
|
||||
);
|
||||
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: client.orgId,
|
||||
userId: olm.userId,
|
||||
sessionId // this is the user token passed in the message
|
||||
});
|
||||
|
||||
if (!policyCheck.allowed) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no client ID!");
|
||||
return;
|
||||
}
|
||||
|
||||
const isUserDevice = olm.userId !== null && olm.userId !== undefined;
|
||||
|
||||
try {
|
||||
// Update the client's last ping timestamp
|
||||
await db
|
||||
.update(clients)
|
||||
.set({
|
||||
lastPing: Math.floor(Date.now() / 1000),
|
||||
online: true
|
||||
})
|
||||
.where(eq(clients.clientId, olm.clientId));
|
||||
// get the client
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, olm.clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client not found for olm ping");
|
||||
return;
|
||||
}
|
||||
|
||||
if (client.blocked) {
|
||||
// NOTE: by returning we dont update the lastPing, so the offline checker will eventually disconnect them
|
||||
logger.debug(
|
||||
`Blocked client ${client.clientId} attempted olm ping`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (olm.userId) {
|
||||
// we need to check a user token to make sure its still valid
|
||||
const { session: userSession, user } =
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
logger.warn("Invalid user session for olm ping");
|
||||
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
|
||||
}
|
||||
if (user.userId !== olm.userId) {
|
||||
logger.warn("User ID mismatch for olm ping");
|
||||
return;
|
||||
}
|
||||
if (user.userId !== client.userId) {
|
||||
logger.warn("Client user ID mismatch for olm ping");
|
||||
return;
|
||||
}
|
||||
|
||||
const sessionId = encodeHexLowerCase(
|
||||
sha256(new TextEncoder().encode(userToken))
|
||||
);
|
||||
|
||||
const policyCheck = await checkOrgAccessPolicy({
|
||||
orgId: client.orgId,
|
||||
userId: olm.userId,
|
||||
sessionId // this is the user token passed in the message
|
||||
});
|
||||
|
||||
if (!policyCheck.allowed) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} does not pass access policies for org ${client.orgId}: ${policyCheck.error}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// get the version
|
||||
logger.debug(
|
||||
`handleOlmPingMessage: About to get config version for olmId: ${olm.olmId}`
|
||||
);
|
||||
const configVersion = await getClientConfigVersion(olm.olmId);
|
||||
logger.debug(
|
||||
`handleOlmPingMessage: Got config version: ${configVersion} (type: ${typeof configVersion})`
|
||||
);
|
||||
|
||||
if (configVersion == null || configVersion === undefined) {
|
||||
logger.debug(
|
||||
`handleOlmPingMessage: could not get config version from server for olmId: ${olm.olmId}`
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
message.configVersion != null &&
|
||||
configVersion != null &&
|
||||
configVersion != message.configVersion
|
||||
) {
|
||||
logger.debug(
|
||||
`handleOlmPingMessage: Olm ping with outdated config version: ${message.configVersion} (current: ${configVersion})`
|
||||
);
|
||||
await sendOlmSyncMessage(olm, client);
|
||||
}
|
||||
|
||||
// Record the ping in memory; it will be flushed to the database
|
||||
// periodically by the ping accumulator (every ~10s) in a single
|
||||
// batched UPDATE instead of one query per ping. This prevents
|
||||
// connection pool exhaustion under load, especially with
|
||||
// cross-region latency to the database.
|
||||
recordClientPing(olm.clientId, olm.olmId, !!olm.archived);
|
||||
} catch (error) {
|
||||
logger.error("Error handling ping message", { error });
|
||||
}
|
||||
|
||||
if (isUserDevice) {
|
||||
await handleFingerprintInsertion(olm, fingerprint, postures);
|
||||
}
|
||||
|
||||
return {
|
||||
message: {
|
||||
type: "pong",
|
||||
|
||||
@@ -1,28 +1,25 @@
|
||||
import {
|
||||
clientSiteResourcesAssociationsCache,
|
||||
db,
|
||||
orgs,
|
||||
siteResources
|
||||
} from "@server/db";
|
||||
import { db, orgs } from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import {
|
||||
clients,
|
||||
clientSitesAssociationsCache,
|
||||
exitNodes,
|
||||
Olm,
|
||||
olms,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import { and, eq, inArray, isNull } from "drizzle-orm";
|
||||
import { addPeer, deletePeer } from "../newt/peers";
|
||||
import { count, eq } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { generateAliasConfig } from "@server/lib/ip";
|
||||
import { generateRemoteSubnets } from "@server/lib/ip";
|
||||
import { checkOrgAccessPolicy } from "#dynamic/lib/checkOrgAccessPolicy";
|
||||
import { validateSessionToken } from "@server/auth/sessions/app";
|
||||
import config from "@server/lib/config";
|
||||
import { encodeHexLowerCase } from "@oslojs/encoding";
|
||||
import { sha256 } from "@oslojs/crypto/sha2";
|
||||
import { getUserDeviceName } from "@server/db/names";
|
||||
import { buildSiteConfigurationForOlmClient } from "./buildConfiguration";
|
||||
import { OlmErrorCodes, sendOlmError } from "./error";
|
||||
import { handleFingerprintInsertion } from "./fingerprintingUtils";
|
||||
import { Alias } from "@server/lib/ip";
|
||||
import { build } from "@server/build";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
|
||||
export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
logger.info("Handling register olm message!");
|
||||
@@ -36,14 +33,51 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { publicKey, relay, olmVersion, olmAgent, orgId, userToken } =
|
||||
message.data;
|
||||
const {
|
||||
publicKey,
|
||||
relay,
|
||||
olmVersion,
|
||||
olmAgent,
|
||||
orgId,
|
||||
userToken,
|
||||
fingerprint,
|
||||
postures,
|
||||
chainId
|
||||
} = message.data;
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm client ID not found");
|
||||
sendOlmError(OlmErrorCodes.CLIENT_ID_NOT_FOUND, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
logger.debug("Handling fingerprint insertion for olm register...", {
|
||||
olmId: olm.olmId,
|
||||
fingerprint,
|
||||
postures
|
||||
});
|
||||
|
||||
const isUserDevice = olm.userId !== null && olm.userId !== undefined;
|
||||
|
||||
if (isUserDevice) {
|
||||
await handleFingerprintInsertion(olm, fingerprint, postures);
|
||||
}
|
||||
|
||||
if (
|
||||
(olmVersion && olm.version !== olmVersion) ||
|
||||
(olmAgent && olm.agent !== olmAgent) ||
|
||||
olm.archived
|
||||
) {
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
version: olmVersion,
|
||||
agent: olmAgent,
|
||||
archived: false
|
||||
})
|
||||
.where(eq(olms.olmId, olm.olmId));
|
||||
}
|
||||
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
@@ -52,9 +86,41 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client ID not found");
|
||||
sendOlmError(OlmErrorCodes.CLIENT_NOT_FOUND, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (client.blocked) {
|
||||
logger.debug(
|
||||
`Client ${client.clientId} is blocked. Ignoring register.`
|
||||
);
|
||||
sendOlmError(OlmErrorCodes.CLIENT_BLOCKED, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (client.approvalState == "pending") {
|
||||
logger.debug(
|
||||
`Client ${client.clientId} approval is pending. Ignoring register.`
|
||||
);
|
||||
sendOlmError(OlmErrorCodes.CLIENT_PENDING, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
const deviceModel = fingerprint?.deviceModel ?? null;
|
||||
const computedName = getUserDeviceName(deviceModel, client.name);
|
||||
if (computedName && computedName !== client.name) {
|
||||
await db
|
||||
.update(clients)
|
||||
.set({ name: computedName })
|
||||
.where(eq(clients.clientId, client.clientId));
|
||||
}
|
||||
if (computedName && computedName !== olm.name) {
|
||||
await db
|
||||
.update(olms)
|
||||
.set({ name: computedName })
|
||||
.where(eq(olms.olmId, olm.olmId));
|
||||
}
|
||||
|
||||
const [org] = await db
|
||||
.select()
|
||||
.from(orgs)
|
||||
@@ -63,12 +129,14 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
|
||||
if (!org) {
|
||||
logger.warn("Org not found");
|
||||
sendOlmError(OlmErrorCodes.ORG_NOT_FOUND, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (orgId) {
|
||||
if (!olm.userId) {
|
||||
logger.warn("Olm has no user ID");
|
||||
sendOlmError(OlmErrorCodes.USER_ID_NOT_FOUND, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -76,10 +144,12 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
await validateSessionToken(userToken);
|
||||
if (!userSession || !user) {
|
||||
logger.warn("Invalid user session for olm register");
|
||||
return; // by returning here we just ignore the ping and the setInterval will force it to disconnect
|
||||
sendOlmError(OlmErrorCodes.INVALID_USER_SESSION, olm.olmId);
|
||||
return;
|
||||
}
|
||||
if (user.userId !== olm.userId) {
|
||||
logger.warn("User ID mismatch for olm register");
|
||||
sendOlmError(OlmErrorCodes.USER_ID_MISMATCH, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -93,14 +163,80 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
sessionId // this is the user token passed in the message
|
||||
});
|
||||
|
||||
if (!policyCheck.allowed) {
|
||||
logger.debug("Policy check result:", policyCheck);
|
||||
|
||||
if (policyCheck?.error) {
|
||||
logger.error(
|
||||
`Error checking access policies for olm user ${olm.userId} in org ${orgId}: ${policyCheck?.error}`
|
||||
);
|
||||
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
|
||||
return;
|
||||
}
|
||||
|
||||
if (policyCheck.policies?.passwordAge?.compliant === false) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} has non-compliant password age for org ${orgId}`
|
||||
);
|
||||
sendOlmError(
|
||||
OlmErrorCodes.ORG_ACCESS_POLICY_PASSWORD_EXPIRED,
|
||||
olm.olmId
|
||||
);
|
||||
return;
|
||||
} else if (
|
||||
policyCheck.policies?.maxSessionLength?.compliant === false
|
||||
) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} has non-compliant session length for org ${orgId}`
|
||||
);
|
||||
sendOlmError(
|
||||
OlmErrorCodes.ORG_ACCESS_POLICY_SESSION_EXPIRED,
|
||||
olm.olmId
|
||||
);
|
||||
return;
|
||||
} else if (policyCheck.policies?.requiredTwoFactor === false) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} does not have 2FA enabled for org ${orgId}`
|
||||
);
|
||||
sendOlmError(
|
||||
OlmErrorCodes.ORG_ACCESS_POLICY_2FA_REQUIRED,
|
||||
olm.olmId
|
||||
);
|
||||
return;
|
||||
} else if (!policyCheck.allowed) {
|
||||
logger.warn(
|
||||
`Olm user ${olm.userId} does not pass access policies for org ${orgId}: ${policyCheck.error}`
|
||||
);
|
||||
sendOlmError(OlmErrorCodes.ORG_ACCESS_POLICY_DENIED, olm.olmId);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Get all sites data
|
||||
const sitesCountResult = await db
|
||||
.select({ count: count() })
|
||||
.from(sites)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Extract the count value from the result array
|
||||
const sitesCount =
|
||||
sitesCountResult.length > 0 ? sitesCountResult[0].count : 0;
|
||||
|
||||
// Prepare an array to store site configurations
|
||||
logger.debug(`Found ${sitesCount} sites for client ${client.clientId}`);
|
||||
|
||||
let jitMode = false;
|
||||
if (sitesCount > 250 && build == "saas") {
|
||||
// THIS IS THE MAX ON THE BUSINESS TIER
|
||||
// we have too many sites
|
||||
// If we have too many sites we need to drop into fully JIT mode by not sending any of the sites
|
||||
logger.info("Too many sites (%d), dropping into JIT mode", sitesCount);
|
||||
jitMode = true;
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
`Olm client ID: ${client.clientId}, Public Key: ${publicKey}, Relay: ${relay}`
|
||||
);
|
||||
@@ -110,20 +246,7 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
(olmVersion && olm.version !== olmVersion) ||
|
||||
(olmAgent && olm.agent !== olmAgent)
|
||||
) {
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
version: olmVersion,
|
||||
agent: olmAgent
|
||||
})
|
||||
.where(eq(olms.olmId, olm.olmId));
|
||||
}
|
||||
|
||||
if (client.pubKey !== publicKey) {
|
||||
if (client.pubKey !== publicKey || client.archived) {
|
||||
logger.info(
|
||||
"Public key mismatch. Updating public key and clearing session info..."
|
||||
);
|
||||
@@ -131,7 +254,8 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
await db
|
||||
.update(clients)
|
||||
.set({
|
||||
pubKey: publicKey
|
||||
pubKey: publicKey,
|
||||
archived: false
|
||||
})
|
||||
.where(eq(clients.clientId, client.clientId));
|
||||
|
||||
@@ -139,161 +263,29 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
await db
|
||||
.update(clientSitesAssociationsCache)
|
||||
.set({
|
||||
isRelayed: relay == true
|
||||
isRelayed: relay == true,
|
||||
isJitMode: jitMode
|
||||
})
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
}
|
||||
|
||||
// Get all sites data
|
||||
const sitesData = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.innerJoin(
|
||||
clientSitesAssociationsCache,
|
||||
eq(sites.siteId, clientSitesAssociationsCache.siteId)
|
||||
)
|
||||
.where(eq(clientSitesAssociationsCache.clientId, client.clientId));
|
||||
|
||||
// Prepare an array to store site configurations
|
||||
const siteConfigurations = [];
|
||||
logger.debug(
|
||||
`Found ${sitesData.length} sites for client ${client.clientId}`
|
||||
);
|
||||
|
||||
// this prevents us from accepting a register from an olm that has not hole punched yet.
|
||||
// the olm will pump the register so we can keep checking
|
||||
// TODO: I still think there is a better way to do this rather than locking it out here but ???
|
||||
if (now - (client.lastHolePunch || 0) > 5 && sitesData.length > 0) {
|
||||
if (now - (client.lastHolePunch || 0) > 5 && sitesCount > 0) {
|
||||
logger.warn(
|
||||
"Client last hole punch is too old and we have sites to send; skipping this register"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Process each site
|
||||
for (const {
|
||||
sites: site,
|
||||
clientSitesAssociationsCache: association
|
||||
} of sitesData) {
|
||||
if (!site.exitNodeId) {
|
||||
logger.warn(
|
||||
`Site ${site.siteId} does not have exit node, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate endpoint and hole punch status
|
||||
if (!site.endpoint) {
|
||||
logger.warn(
|
||||
`In olm register: site ${site.siteId} has no endpoint, skipping`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// if (site.lastHolePunch && now - site.lastHolePunch > 6 && relay) {
|
||||
// logger.warn(
|
||||
// `Site ${site.siteId} last hole punch is too old, skipping`
|
||||
// );
|
||||
// continue;
|
||||
// }
|
||||
|
||||
// If public key changed, delete old peer from this site
|
||||
if (client.pubKey && client.pubKey != publicKey) {
|
||||
logger.info(
|
||||
`Public key mismatch. Deleting old peer from site ${site.siteId}...`
|
||||
);
|
||||
await deletePeer(site.siteId, client.pubKey!);
|
||||
}
|
||||
|
||||
if (!site.subnet) {
|
||||
logger.warn(`Site ${site.siteId} has no subnet, skipping`);
|
||||
continue;
|
||||
}
|
||||
|
||||
const [clientSite] = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
)
|
||||
.limit(1);
|
||||
|
||||
// Add the peer to the exit node for this site
|
||||
if (clientSite.endpoint) {
|
||||
logger.info(
|
||||
`Adding peer ${publicKey} to site ${site.siteId} with endpoint ${clientSite.endpoint}`
|
||||
);
|
||||
await addPeer(site.siteId, {
|
||||
publicKey: publicKey,
|
||||
allowedIps: [`${client.subnet.split("/")[0]}/32`], // we want to only allow from that client
|
||||
endpoint: relay ? "" : clientSite.endpoint
|
||||
});
|
||||
} else {
|
||||
logger.warn(
|
||||
`Client ${client.clientId} has no endpoint, skipping peer addition`
|
||||
);
|
||||
}
|
||||
|
||||
let relayEndpoint: string | undefined = undefined;
|
||||
if (relay) {
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId))
|
||||
.limit(1);
|
||||
if (!exitNode) {
|
||||
logger.warn(`Exit node not found for site ${site.siteId}`);
|
||||
continue;
|
||||
}
|
||||
relayEndpoint = `${exitNode.endpoint}:${config.getRawConfig().gerbil.clients_start_port}`;
|
||||
}
|
||||
|
||||
const allSiteResources = await db // only get the site resources that this client has access to
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.innerJoin(
|
||||
clientSiteResourcesAssociationsCache,
|
||||
eq(
|
||||
siteResources.siteResourceId,
|
||||
clientSiteResourcesAssociationsCache.siteResourceId
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(siteResources.siteId, site.siteId),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
// Add site configuration to the array
|
||||
siteConfigurations.push({
|
||||
siteId: site.siteId,
|
||||
name: site.name,
|
||||
// relayEndpoint: relayEndpoint, // this can be undefined now if not relayed // lets not do this for now because it would conflict with the hole punch testing
|
||||
endpoint: site.endpoint,
|
||||
publicKey: site.publicKey,
|
||||
serverIP: site.address,
|
||||
serverPort: site.listenPort,
|
||||
remoteSubnets: generateRemoteSubnets(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
// REMOVED THIS SO IT CREATES THE INTERFACE AND JUST WAITS FOR THE SITES
|
||||
// if (siteConfigurations.length === 0) {
|
||||
// logger.warn("No valid site configurations found");
|
||||
// return;
|
||||
// }
|
||||
// NOTE: its important that the client here is the old client and the public key is the new key
|
||||
const siteConfigurations = await buildSiteConfigurationForOlmClient(
|
||||
client,
|
||||
publicKey,
|
||||
relay,
|
||||
jitMode
|
||||
);
|
||||
|
||||
// Return connect message with all site configurations
|
||||
return {
|
||||
@@ -302,9 +294,13 @@ export const handleOlmRegisterMessage: MessageHandler = async (context) => {
|
||||
data: {
|
||||
sites: siteConfigurations,
|
||||
tunnelIP: client.subnet,
|
||||
utilitySubnet: org.utilitySubnet
|
||||
utilitySubnet: org.utilitySubnet,
|
||||
chainId: chainId
|
||||
}
|
||||
},
|
||||
options: {
|
||||
compress: canCompress(olm.version, "olm")
|
||||
},
|
||||
broadcast: false,
|
||||
excludeSender: false
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ import { clients, clientSitesAssociationsCache, Olm } from "@server/db";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { updatePeer as newtUpdatePeer } from "../newt/peers";
|
||||
import logger from "@server/logger";
|
||||
import config from "@server/lib/config";
|
||||
|
||||
export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
||||
const { message, client: c, sendToClient } = context;
|
||||
@@ -17,7 +18,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
|
||||
logger.warn("Olm has no client!");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -40,7 +41,7 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { siteId } = message.data;
|
||||
const { siteId, chainId } = message.data;
|
||||
|
||||
// Get the site
|
||||
const [site] = await db
|
||||
@@ -88,7 +89,9 @@ export const handleOlmRelayMessage: MessageHandler = async (context) => {
|
||||
type: "olm/wg/peer/relay",
|
||||
data: {
|
||||
siteId: siteId,
|
||||
relayEndpoint: exitNode.endpoint
|
||||
relayEndpoint: exitNode.endpoint,
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
chainId
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
|
||||
241
server/routers/olm/handleOlmServerInitAddPeerHandshake.ts
Normal file
241
server/routers/olm/handleOlmServerInitAddPeerHandshake.ts
Normal file
@@ -0,0 +1,241 @@
|
||||
import {
|
||||
clientSiteResourcesAssociationsCache,
|
||||
clientSitesAssociationsCache,
|
||||
db,
|
||||
exitNodes,
|
||||
Site,
|
||||
siteResources
|
||||
} from "@server/db";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { clients, Olm, sites } from "@server/db";
|
||||
import { and, eq, or } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { initPeerAddHandshake } from "./peers";
|
||||
|
||||
export const handleOlmServerInitAddPeerHandshake: MessageHandler = async (
|
||||
context
|
||||
) => {
|
||||
logger.info("Handling register olm message!");
|
||||
const { message, client: c, sendToClient } = context;
|
||||
const olm = c as Olm;
|
||||
|
||||
if (!olm) {
|
||||
logger.warn("Olm not found");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no client!"); // TODO: Maybe we create the site here?
|
||||
return;
|
||||
}
|
||||
|
||||
const clientId = olm.clientId;
|
||||
|
||||
const [client] = await db
|
||||
.select()
|
||||
.from(clients)
|
||||
.where(eq(clients.clientId, clientId))
|
||||
.limit(1);
|
||||
|
||||
if (!client) {
|
||||
logger.warn("Client not found");
|
||||
return;
|
||||
}
|
||||
|
||||
const { siteId, resourceId, chainId } = message.data;
|
||||
|
||||
let site: Site | null = null;
|
||||
if (siteId) {
|
||||
// get the site
|
||||
const [siteRes] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteId))
|
||||
.limit(1);
|
||||
if (siteRes) {
|
||||
site = siteRes;
|
||||
}
|
||||
}
|
||||
|
||||
if (resourceId && !site) {
|
||||
const resources = await db
|
||||
.select()
|
||||
.from(siteResources)
|
||||
.where(
|
||||
and(
|
||||
or(
|
||||
eq(siteResources.niceId, resourceId),
|
||||
eq(siteResources.alias, resourceId)
|
||||
),
|
||||
eq(siteResources.orgId, client.orgId)
|
||||
)
|
||||
);
|
||||
|
||||
if (!resources || resources.length === 0) {
|
||||
logger.error(`handleOlmServerPeerAddMessage: Resource not found`);
|
||||
// cancel the request from the olm side to not keep doing this
|
||||
await sendToClient(
|
||||
olm.olmId,
|
||||
{
|
||||
type: "olm/wg/peer/chain/cancel",
|
||||
data: {
|
||||
chainId
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: false }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (resources.length > 1) {
|
||||
// error but this should not happen because the nice id cant contain a dot and the alias has to have a dot and both have to be unique within the org so there should never be multiple matches
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Multiple resources found matching the criteria`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const resource = resources[0];
|
||||
|
||||
const currentResourceAssociationCaches = await db
|
||||
.select()
|
||||
.from(clientSiteResourcesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.siteResourceId,
|
||||
resource.siteResourceId
|
||||
),
|
||||
eq(
|
||||
clientSiteResourcesAssociationsCache.clientId,
|
||||
client.clientId
|
||||
)
|
||||
)
|
||||
);
|
||||
|
||||
if (currentResourceAssociationCaches.length === 0) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to resource ${resource.siteResourceId}`
|
||||
);
|
||||
// cancel the request from the olm side to not keep doing this
|
||||
await sendToClient(
|
||||
olm.olmId,
|
||||
{
|
||||
type: "olm/wg/peer/chain/cancel",
|
||||
data: {
|
||||
chainId
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: false }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const siteIdFromResource = resource.siteId;
|
||||
|
||||
// get the site
|
||||
const [siteRes] = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(eq(sites.siteId, siteIdFromResource));
|
||||
if (!siteRes) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Site with ID ${site} not found`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
site = siteRes;
|
||||
}
|
||||
|
||||
if (!site) {
|
||||
logger.error(`handleOlmServerPeerAddMessage: Site not found`);
|
||||
return;
|
||||
}
|
||||
|
||||
// check if the client can access this site using the cache
|
||||
const currentSiteAssociationCaches = await db
|
||||
.select()
|
||||
.from(clientSitesAssociationsCache)
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, client.clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, site.siteId)
|
||||
)
|
||||
);
|
||||
|
||||
if (currentSiteAssociationCaches.length === 0) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Client ${client.clientId} does not have access to site ${site.siteId}`
|
||||
);
|
||||
// cancel the request from the olm side to not keep doing this
|
||||
await sendToClient(
|
||||
olm.olmId,
|
||||
{
|
||||
type: "olm/wg/peer/chain/cancel",
|
||||
data: {
|
||||
chainId
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: false }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (!site.exitNodeId) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
|
||||
);
|
||||
// cancel the request from the olm side to not keep doing this
|
||||
await sendToClient(
|
||||
olm.olmId,
|
||||
{
|
||||
type: "olm/wg/peer/chain/cancel",
|
||||
data: {
|
||||
chainId
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: false }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// get the exit node from the side
|
||||
const [exitNode] = await db
|
||||
.select()
|
||||
.from(exitNodes)
|
||||
.where(eq(exitNodes.exitNodeId, site.exitNodeId));
|
||||
|
||||
if (!exitNode) {
|
||||
logger.error(
|
||||
`handleOlmServerPeerAddMessage: Site with ID ${site.siteId} has no exit node`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// also trigger the peer add handshake in case the peer was not already added to the olm and we need to hole punch
|
||||
// if it has already been added this will be a no-op
|
||||
await initPeerAddHandshake(
|
||||
// this will kick off the add peer process for the client
|
||||
client.clientId,
|
||||
{
|
||||
siteId: site.siteId,
|
||||
exitNode: {
|
||||
publicKey: exitNode.publicKey,
|
||||
endpoint: exitNode.endpoint
|
||||
}
|
||||
},
|
||||
olm.olmId,
|
||||
chainId
|
||||
);
|
||||
|
||||
return;
|
||||
};
|
||||
@@ -54,7 +54,7 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
|
||||
return;
|
||||
}
|
||||
|
||||
const { siteId } = message.data;
|
||||
const { siteId, chainId } = message.data;
|
||||
|
||||
// get the site
|
||||
const [site] = await db
|
||||
@@ -179,7 +179,8 @@ export const handleOlmServerPeerAddMessage: MessageHandler = async (
|
||||
),
|
||||
aliases: generateAliasConfig(
|
||||
allSiteResources.map(({ siteResources }) => siteResources)
|
||||
)
|
||||
),
|
||||
chainId: chainId,
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
|
||||
@@ -17,7 +17,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
||||
}
|
||||
|
||||
if (!olm.clientId) {
|
||||
logger.warn("Olm has no site!"); // TODO: Maybe we create the site here?
|
||||
logger.warn("Olm has no client!");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
||||
return;
|
||||
}
|
||||
|
||||
const { siteId } = message.data;
|
||||
const { siteId, chainId } = message.data;
|
||||
|
||||
// Get the site
|
||||
const [site] = await db
|
||||
@@ -87,7 +87,8 @@ export const handleOlmUnRelayMessage: MessageHandler = async (context) => {
|
||||
type: "olm/wg/peer/unrelay",
|
||||
data: {
|
||||
siteId: siteId,
|
||||
endpoint: site.endpoint
|
||||
endpoint: site.endpoint,
|
||||
chainId
|
||||
}
|
||||
},
|
||||
broadcast: false,
|
||||
|
||||
@@ -3,9 +3,12 @@ export * from "./getOlmToken";
|
||||
export * from "./createUserOlm";
|
||||
export * from "./handleOlmRelayMessage";
|
||||
export * from "./handleOlmPingMessage";
|
||||
export * from "./deleteUserOlm";
|
||||
export * from "./archiveUserOlm";
|
||||
export * from "./unarchiveUserOlm";
|
||||
export * from "./listUserOlms";
|
||||
export * from "./deleteUserOlm";
|
||||
export * from "./getUserOlm";
|
||||
export * from "./handleOlmServerPeerAddMessage";
|
||||
export * from "./handleOlmUnRelayMessage";
|
||||
export * from "./recoverOlmWithFingerprint";
|
||||
export * from "./handleOlmDisconnectingMessage";
|
||||
export * from "./handleOlmServerInitAddPeerHandshake";
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import { db } from "@server/db";
|
||||
import { db, currentFingerprint } from "@server/db";
|
||||
import { olms } from "@server/db";
|
||||
import { eq, count, desc } from "drizzle-orm";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
@@ -9,6 +9,7 @@ import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import logger from "@server/logger";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { getUserDeviceName } from "@server/db/names";
|
||||
|
||||
const querySchema = z.object({
|
||||
limit: z
|
||||
@@ -51,6 +52,7 @@ export type ListUserOlmsResponse = {
|
||||
name: string | null;
|
||||
clientId: number | null;
|
||||
userId: string | null;
|
||||
archived: boolean;
|
||||
}>;
|
||||
pagination: {
|
||||
total: number;
|
||||
@@ -89,7 +91,7 @@ export async function listUserOlms(
|
||||
|
||||
const { userId } = parsedParams.data;
|
||||
|
||||
// Get total count
|
||||
// Get total count (including archived OLMs)
|
||||
const [totalCountResult] = await db
|
||||
.select({ count: count() })
|
||||
.from(olms)
|
||||
@@ -97,22 +99,34 @@ export async function listUserOlms(
|
||||
|
||||
const total = totalCountResult?.count || 0;
|
||||
|
||||
// Get OLMs for the current user
|
||||
const userOlms = await db
|
||||
.select({
|
||||
olmId: olms.olmId,
|
||||
dateCreated: olms.dateCreated,
|
||||
version: olms.version,
|
||||
name: olms.name,
|
||||
clientId: olms.clientId,
|
||||
userId: olms.userId
|
||||
})
|
||||
// Get OLMs for the current user (including archived OLMs)
|
||||
const list = await db
|
||||
.select()
|
||||
.from(olms)
|
||||
.where(eq(olms.userId, userId))
|
||||
.leftJoin(
|
||||
currentFingerprint,
|
||||
eq(olms.olmId, currentFingerprint.olmId)
|
||||
)
|
||||
.orderBy(desc(olms.dateCreated))
|
||||
.limit(limit)
|
||||
.offset(offset);
|
||||
|
||||
const userOlms = list.map((item) => {
|
||||
const model = item.currentFingerprint?.deviceModel || null;
|
||||
const newName = getUserDeviceName(model, item.olms.name);
|
||||
|
||||
return {
|
||||
olmId: item.olms.olmId,
|
||||
dateCreated: item.olms.dateCreated,
|
||||
version: item.olms.version,
|
||||
name: newName,
|
||||
clientId: item.olms.clientId,
|
||||
userId: item.olms.userId,
|
||||
archived: item.olms.archived
|
||||
};
|
||||
});
|
||||
|
||||
return response<ListUserOlmsResponse>(res, {
|
||||
data: {
|
||||
olms: userOlms,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { sendToClient } from "#dynamic/routers/ws";
|
||||
import { db, olms } from "@server/db";
|
||||
import { clientSitesAssociationsCache, db, olms } from "@server/db";
|
||||
import { canCompress } from "@server/lib/clientVersionChecks";
|
||||
import config from "@server/lib/config";
|
||||
import logger from "@server/logger";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { Alias } from "yaml";
|
||||
|
||||
export async function addPeer(
|
||||
@@ -17,7 +19,8 @@ export async function addPeer(
|
||||
remoteSubnets: string[] | null; // optional, comma-separated list of subnets that this site can access
|
||||
aliases: Alias[];
|
||||
},
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
version?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -29,22 +32,27 @@ export async function addPeer(
|
||||
return; // ignore this because an olm might not be associated with the client anymore
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
version = olm.version;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
name: peer.name,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
|
||||
aliases: peer.aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
name: peer.name,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets, // optional, comma-separated list of subnets that this site can access
|
||||
aliases: peer.aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -55,7 +63,8 @@ export async function deletePeer(
|
||||
clientId: number,
|
||||
siteId: number,
|
||||
publicKey: string,
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
version?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -67,15 +76,20 @@ export async function deletePeer(
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
version = olm.version;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/remove",
|
||||
data: {
|
||||
publicKey,
|
||||
siteId: siteId
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/remove",
|
||||
data: {
|
||||
publicKey,
|
||||
siteId: siteId
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -94,7 +108,8 @@ export async function updatePeer(
|
||||
remoteSubnets?: string[] | null; // optional, comma-separated list of subnets that
|
||||
aliases?: Alias[] | null;
|
||||
},
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
version?: string | null
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -106,21 +121,26 @@ export async function updatePeer(
|
||||
return;
|
||||
}
|
||||
olmId = olm.olmId;
|
||||
version = olm.version;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/update",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets,
|
||||
aliases: peer.aliases
|
||||
}
|
||||
}).catch((error) => {
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/update",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
publicKey: peer.publicKey,
|
||||
endpoint: peer.endpoint,
|
||||
relayEndpoint: peer.relayEndpoint,
|
||||
serverIP: peer.serverIP,
|
||||
serverPort: peer.serverPort,
|
||||
remoteSubnets: peer.remoteSubnets,
|
||||
aliases: peer.aliases
|
||||
}
|
||||
},
|
||||
{ incrementConfigVersion: true, compress: canCompress(version, "olm") }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
@@ -136,7 +156,8 @@ export async function initPeerAddHandshake(
|
||||
endpoint: string;
|
||||
};
|
||||
},
|
||||
olmId?: string
|
||||
olmId?: string,
|
||||
chainId?: string
|
||||
) {
|
||||
if (!olmId) {
|
||||
const [olm] = await db
|
||||
@@ -150,19 +171,36 @@ export async function initPeerAddHandshake(
|
||||
olmId = olm.olmId;
|
||||
}
|
||||
|
||||
await sendToClient(olmId, {
|
||||
type: "olm/wg/peer/holepunch/site/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
exitNode: {
|
||||
publicKey: peer.exitNode.publicKey,
|
||||
endpoint: peer.exitNode.endpoint
|
||||
await sendToClient(
|
||||
olmId,
|
||||
{
|
||||
type: "olm/wg/peer/holepunch/site/add",
|
||||
data: {
|
||||
siteId: peer.siteId,
|
||||
exitNode: {
|
||||
publicKey: peer.exitNode.publicKey,
|
||||
relayPort: config.getRawConfig().gerbil.clients_start_port,
|
||||
endpoint: peer.exitNode.endpoint
|
||||
},
|
||||
chainId
|
||||
}
|
||||
}
|
||||
}).catch((error) => {
|
||||
},
|
||||
{ incrementConfigVersion: true }
|
||||
).catch((error) => {
|
||||
logger.warn(`Error sending message:`, error);
|
||||
});
|
||||
|
||||
// update the clientSiteAssociationsCache to make the isJitMode flag false so that JIT mode is disabled for this site if it restarts or something after the connection
|
||||
await db
|
||||
.update(clientSitesAssociationsCache)
|
||||
.set({ isJitMode: false })
|
||||
.where(
|
||||
and(
|
||||
eq(clientSitesAssociationsCache.clientId, clientId),
|
||||
eq(clientSitesAssociationsCache.siteId, peer.siteId)
|
||||
)
|
||||
);
|
||||
|
||||
logger.info(
|
||||
`Initiated peer add handshake for site ${peer.siteId} to olm ${olmId}`
|
||||
);
|
||||
|
||||
125
server/routers/olm/recoverOlmWithFingerprint.ts
Normal file
125
server/routers/olm/recoverOlmWithFingerprint.ts
Normal file
@@ -0,0 +1,125 @@
|
||||
import { db, currentFingerprint, olms } from "@server/db";
|
||||
import logger from "@server/logger";
|
||||
import HttpCode from "@server/types/HttpCode";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { NextFunction, Request, Response } from "express";
|
||||
import response from "@server/lib/response";
|
||||
import createHttpError from "http-errors";
|
||||
import { z } from "zod";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { generateId } from "@server/auth/sessions/app";
|
||||
import { hashPassword } from "@server/auth/password";
|
||||
|
||||
const paramsSchema = z
|
||||
.object({
|
||||
userId: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
const bodySchema = z
|
||||
.object({
|
||||
platformFingerprint: z.string()
|
||||
})
|
||||
.strict();
|
||||
|
||||
export async function recoverOlmWithFingerprint(
|
||||
req: Request,
|
||||
res: Response,
|
||||
next: NextFunction
|
||||
): Promise<any> {
|
||||
try {
|
||||
const parsedParams = paramsSchema.safeParse(req.params);
|
||||
if (!parsedParams.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedParams.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { userId } = parsedParams.data;
|
||||
|
||||
const parsedBody = bodySchema.safeParse(req.body);
|
||||
if (!parsedBody.success) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
fromError(parsedBody.error).toString()
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const { platformFingerprint } = parsedBody.data;
|
||||
|
||||
const result = await db
|
||||
.select({
|
||||
olm: olms,
|
||||
fingerprint: currentFingerprint
|
||||
})
|
||||
.from(olms)
|
||||
.innerJoin(
|
||||
currentFingerprint,
|
||||
eq(currentFingerprint.olmId, olms.olmId)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
eq(olms.userId, userId),
|
||||
eq(
|
||||
currentFingerprint.platformFingerprint,
|
||||
platformFingerprint
|
||||
)
|
||||
)
|
||||
)
|
||||
.orderBy(currentFingerprint.lastSeen);
|
||||
|
||||
if (!result || result.length == 0) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.NOT_FOUND,
|
||||
"corresponding olm with this fingerprint not found"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (result.length > 1) {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.CONFLICT,
|
||||
"multiple matching fingerprints found, not resetting secrets"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
const [{ olm: foundOlm }] = result;
|
||||
|
||||
const newSecret = generateId(48);
|
||||
const newSecretHash = await hashPassword(newSecret);
|
||||
|
||||
await db
|
||||
.update(olms)
|
||||
.set({
|
||||
secretHash: newSecretHash
|
||||
})
|
||||
.where(eq(olms.olmId, foundOlm.olmId));
|
||||
|
||||
return response(res, {
|
||||
data: {
|
||||
olmId: foundOlm.olmId,
|
||||
secret: newSecret
|
||||
},
|
||||
success: true,
|
||||
error: false,
|
||||
message: "Successfully retrieved olm",
|
||||
status: HttpCode.OK
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error(error);
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.INTERNAL_SERVER_ERROR,
|
||||
"Failed to recover olm using provided fingerprint input"
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user