mirror of
https://github.com/fosrl/pangolin.git
synced 2026-03-26 04:26:38 +00:00
Merge branch 'dev' into multi-role
This commit is contained in:
@@ -32,7 +32,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/access/export",
|
||||
description: "Export the access audit log for an organization as CSV",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryAccessAuditLogsQuery,
|
||||
params: queryAccessAuditLogsParams
|
||||
|
||||
@@ -32,7 +32,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/action/export",
|
||||
description: "Export the action audit log for an organization as CSV",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryActionAuditLogsQuery,
|
||||
params: queryActionAuditLogsParams
|
||||
|
||||
@@ -11,11 +11,11 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { accessAuditLog, db, resources } from "@server/db";
|
||||
import { accessAuditLog, logsDb, resources, db, primaryDb } from "@server/db";
|
||||
import { registry } from "@server/openApi";
|
||||
import { NextFunction } from "express";
|
||||
import { Request, Response } from "express";
|
||||
import { eq, gt, lt, and, count, desc } from "drizzle-orm";
|
||||
import { eq, gt, lt, and, count, desc, inArray } from "drizzle-orm";
|
||||
import { OpenAPITags } from "@server/openApi";
|
||||
import { z } from "zod";
|
||||
import createHttpError from "http-errors";
|
||||
@@ -115,15 +115,13 @@ function getWhere(data: Q) {
|
||||
}
|
||||
|
||||
export function queryAccess(data: Q) {
|
||||
return db
|
||||
return logsDb
|
||||
.select({
|
||||
orgId: accessAuditLog.orgId,
|
||||
action: accessAuditLog.action,
|
||||
actorType: accessAuditLog.actorType,
|
||||
actorId: accessAuditLog.actorId,
|
||||
resourceId: accessAuditLog.resourceId,
|
||||
resourceName: resources.name,
|
||||
resourceNiceId: resources.niceId,
|
||||
ip: accessAuditLog.ip,
|
||||
location: accessAuditLog.location,
|
||||
userAgent: accessAuditLog.userAgent,
|
||||
@@ -133,16 +131,46 @@ export function queryAccess(data: Q) {
|
||||
actor: accessAuditLog.actor
|
||||
})
|
||||
.from(accessAuditLog)
|
||||
.leftJoin(
|
||||
resources,
|
||||
eq(accessAuditLog.resourceId, resources.resourceId)
|
||||
)
|
||||
.where(getWhere(data))
|
||||
.orderBy(desc(accessAuditLog.timestamp), desc(accessAuditLog.id));
|
||||
}
|
||||
|
||||
async function enrichWithResourceDetails(logs: Awaited<ReturnType<typeof queryAccess>>) {
|
||||
// If logs database is the same as main database, we can do a join
|
||||
// Otherwise, we need to fetch resource details separately
|
||||
const resourceIds = logs
|
||||
.map(log => log.resourceId)
|
||||
.filter((id): id is number => id !== null && id !== undefined);
|
||||
|
||||
if (resourceIds.length === 0) {
|
||||
return logs.map(log => ({ ...log, resourceName: null, resourceNiceId: null }));
|
||||
}
|
||||
|
||||
// Fetch resource details from main database
|
||||
const resourceDetails = await primaryDb
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name,
|
||||
niceId: resources.niceId
|
||||
})
|
||||
.from(resources)
|
||||
.where(inArray(resources.resourceId, resourceIds));
|
||||
|
||||
// Create a map for quick lookup
|
||||
const resourceMap = new Map(
|
||||
resourceDetails.map(r => [r.resourceId, { name: r.name, niceId: r.niceId }])
|
||||
);
|
||||
|
||||
// Enrich logs with resource details
|
||||
return logs.map(log => ({
|
||||
...log,
|
||||
resourceName: log.resourceId ? resourceMap.get(log.resourceId)?.name ?? null : null,
|
||||
resourceNiceId: log.resourceId ? resourceMap.get(log.resourceId)?.niceId ?? null : null
|
||||
}));
|
||||
}
|
||||
|
||||
export function countAccessQuery(data: Q) {
|
||||
const countQuery = db
|
||||
const countQuery = logsDb
|
||||
.select({ count: count() })
|
||||
.from(accessAuditLog)
|
||||
.where(getWhere(data));
|
||||
@@ -161,7 +189,7 @@ async function queryUniqueFilterAttributes(
|
||||
);
|
||||
|
||||
// Get unique actors
|
||||
const uniqueActors = await db
|
||||
const uniqueActors = await logsDb
|
||||
.selectDistinct({
|
||||
actor: accessAuditLog.actor
|
||||
})
|
||||
@@ -169,7 +197,7 @@ async function queryUniqueFilterAttributes(
|
||||
.where(baseConditions);
|
||||
|
||||
// Get unique locations
|
||||
const uniqueLocations = await db
|
||||
const uniqueLocations = await logsDb
|
||||
.selectDistinct({
|
||||
locations: accessAuditLog.location
|
||||
})
|
||||
@@ -177,25 +205,40 @@ async function queryUniqueFilterAttributes(
|
||||
.where(baseConditions);
|
||||
|
||||
// Get unique resources with names
|
||||
const uniqueResources = await db
|
||||
const uniqueResources = await logsDb
|
||||
.selectDistinct({
|
||||
id: accessAuditLog.resourceId,
|
||||
name: resources.name
|
||||
id: accessAuditLog.resourceId
|
||||
})
|
||||
.from(accessAuditLog)
|
||||
.leftJoin(
|
||||
resources,
|
||||
eq(accessAuditLog.resourceId, resources.resourceId)
|
||||
)
|
||||
.where(baseConditions);
|
||||
|
||||
// Fetch resource names from main database for the unique resource IDs
|
||||
const resourceIds = uniqueResources
|
||||
.map(row => row.id)
|
||||
.filter((id): id is number => id !== null);
|
||||
|
||||
let resourcesWithNames: Array<{ id: number; name: string | null }> = [];
|
||||
|
||||
if (resourceIds.length > 0) {
|
||||
const resourceDetails = await primaryDb
|
||||
.select({
|
||||
resourceId: resources.resourceId,
|
||||
name: resources.name
|
||||
})
|
||||
.from(resources)
|
||||
.where(inArray(resources.resourceId, resourceIds));
|
||||
|
||||
resourcesWithNames = resourceDetails.map(r => ({
|
||||
id: r.resourceId,
|
||||
name: r.name
|
||||
}));
|
||||
}
|
||||
|
||||
return {
|
||||
actors: uniqueActors
|
||||
.map((row) => row.actor)
|
||||
.filter((actor): actor is string => actor !== null),
|
||||
resources: uniqueResources.filter(
|
||||
(row): row is { id: number; name: string | null } => row.id !== null
|
||||
),
|
||||
resources: resourcesWithNames,
|
||||
locations: uniqueLocations
|
||||
.map((row) => row.locations)
|
||||
.filter((location): location is string => location !== null)
|
||||
@@ -206,7 +249,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/access",
|
||||
description: "Query the access audit log for an organization",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryAccessAuditLogsQuery,
|
||||
params: queryAccessAuditLogsParams
|
||||
@@ -243,7 +286,10 @@ export async function queryAccessAuditLogs(
|
||||
|
||||
const baseQuery = queryAccess(data);
|
||||
|
||||
const log = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
const logsRaw = await baseQuery.limit(data.limit).offset(data.offset);
|
||||
|
||||
// Enrich with resource details (handles cross-database scenario)
|
||||
const log = await enrichWithResourceDetails(logsRaw);
|
||||
|
||||
const totalCountResult = await countAccessQuery(data);
|
||||
const totalCount = totalCountResult[0].count;
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
* This file is not licensed under the AGPLv3.
|
||||
*/
|
||||
|
||||
import { actionAuditLog, db } from "@server/db";
|
||||
import { actionAuditLog, logsDb } from "@server/db";
|
||||
import { registry } from "@server/openApi";
|
||||
import { NextFunction } from "express";
|
||||
import { Request, Response } from "express";
|
||||
@@ -97,7 +97,7 @@ function getWhere(data: Q) {
|
||||
}
|
||||
|
||||
export function queryAction(data: Q) {
|
||||
return db
|
||||
return logsDb
|
||||
.select({
|
||||
orgId: actionAuditLog.orgId,
|
||||
action: actionAuditLog.action,
|
||||
@@ -113,7 +113,7 @@ export function queryAction(data: Q) {
|
||||
}
|
||||
|
||||
export function countActionQuery(data: Q) {
|
||||
const countQuery = db
|
||||
const countQuery = logsDb
|
||||
.select({ count: count() })
|
||||
.from(actionAuditLog)
|
||||
.where(getWhere(data));
|
||||
@@ -132,14 +132,14 @@ async function queryUniqueFilterAttributes(
|
||||
);
|
||||
|
||||
// Get unique actors
|
||||
const uniqueActors = await db
|
||||
const uniqueActors = await logsDb
|
||||
.selectDistinct({
|
||||
actor: actionAuditLog.actor
|
||||
})
|
||||
.from(actionAuditLog)
|
||||
.where(baseConditions);
|
||||
|
||||
const uniqueActions = await db
|
||||
const uniqueActions = await logsDb
|
||||
.selectDistinct({
|
||||
action: actionAuditLog.action
|
||||
})
|
||||
@@ -160,7 +160,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/logs/action",
|
||||
description: "Query the action audit log for an organization",
|
||||
tags: [OpenAPITags.Org],
|
||||
tags: [OpenAPITags.Logs],
|
||||
request: {
|
||||
query: queryActionAuditLogsQuery,
|
||||
params: queryActionAuditLogsParams
|
||||
|
||||
@@ -31,16 +31,16 @@ const getOrgSchema = z.strictObject({
|
||||
orgId: z.string()
|
||||
});
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/billing/usage",
|
||||
description: "Get an organization's billing usage",
|
||||
tags: [OpenAPITags.Org],
|
||||
request: {
|
||||
params: getOrgSchema
|
||||
},
|
||||
responses: {}
|
||||
});
|
||||
// registry.registerPath({
|
||||
// method: "get",
|
||||
// path: "/org/{orgId}/billing/usage",
|
||||
// description: "Get an organization's billing usage",
|
||||
// tags: [OpenAPITags.Org],
|
||||
// request: {
|
||||
// params: getOrgSchema
|
||||
// },
|
||||
// responses: {}
|
||||
// });
|
||||
|
||||
export async function getOrgUsage(
|
||||
req: Request,
|
||||
|
||||
@@ -24,7 +24,6 @@ import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
|
||||
import { getSubType } from "./getSubType";
|
||||
import privateConfig from "#private/lib/config";
|
||||
import { getLicensePriceSet, LicenseId } from "@server/lib/billing/licenses";
|
||||
@@ -172,7 +171,7 @@ export async function handleSubscriptionCreated(
|
||||
const email = orgUserRes.user.email;
|
||||
|
||||
if (email) {
|
||||
moveEmailToAudience(email, AudienceIds.Subscribed);
|
||||
// TODO: update user in Sendy
|
||||
}
|
||||
}
|
||||
} else if (type === "license") {
|
||||
|
||||
@@ -23,7 +23,6 @@ import {
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import logger from "@server/logger";
|
||||
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
|
||||
import { AudienceIds, moveEmailToAudience } from "#private/lib/resend";
|
||||
import { getSubType } from "./getSubType";
|
||||
import stripe from "#private/lib/stripe";
|
||||
import privateConfig from "#private/lib/config";
|
||||
@@ -109,7 +108,7 @@ export async function handleSubscriptionDeleted(
|
||||
const email = orgUserRes.user.email;
|
||||
|
||||
if (email) {
|
||||
moveEmailToAudience(email, AudienceIds.Churned);
|
||||
// TODO: update user in Sendy
|
||||
}
|
||||
}
|
||||
} else if (type === "license") {
|
||||
|
||||
@@ -480,9 +480,9 @@ authenticated.get(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:clientId/regenerate-client-secret",
|
||||
verifyClientAccess, // this is first to set the org id
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.rotateCredentials),
|
||||
verifyClientAccess, // this is first to set the org id
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateClientSecret
|
||||
@@ -490,9 +490,9 @@ authenticated.post(
|
||||
|
||||
authenticated.post(
|
||||
"/re-key/:siteId/regenerate-site-secret",
|
||||
verifySiteAccess, // this is first to set the org id
|
||||
verifyValidLicense,
|
||||
verifyValidSubscription(tierMatrix.rotateCredentials),
|
||||
verifySiteAccess, // this is first to set the org id
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.reGenerateSecret),
|
||||
reKey.reGenerateSiteSecret
|
||||
@@ -515,6 +515,6 @@ authenticated.post(
|
||||
verifyOrgAccess,
|
||||
verifyLimits,
|
||||
verifyUserHasAction(ActionsEnum.signSshKey),
|
||||
logActionAudit(ActionsEnum.signSshKey),
|
||||
// logActionAudit(ActionsEnum.signSshKey), // it is handled inside of the function below so we can include more metadata
|
||||
ssh.signSshKey
|
||||
);
|
||||
|
||||
@@ -15,6 +15,7 @@ import { verifySessionRemoteExitNodeMiddleware } from "#private/middlewares/veri
|
||||
import { Router } from "express";
|
||||
import {
|
||||
db,
|
||||
logsDb,
|
||||
exitNodes,
|
||||
Resource,
|
||||
ResourcePassword,
|
||||
@@ -81,6 +82,7 @@ import { verifyResourceAccessToken } from "@server/auth/verifyResourceAccessToke
|
||||
import semver from "semver";
|
||||
import { maxmindAsnLookup } from "@server/db/maxmindAsn";
|
||||
import { checkOrgAccessPolicy } from "@server/lib/checkOrgAccessPolicy";
|
||||
import { sanitizeString } from "@server/lib/sanitize";
|
||||
|
||||
// Zod schemas for request validation
|
||||
const getResourceByDomainParamsSchema = z.strictObject({
|
||||
@@ -1859,24 +1861,24 @@ hybridRouter.post(
|
||||
})
|
||||
.map((logEntry) => ({
|
||||
timestamp: logEntry.timestamp,
|
||||
orgId: logEntry.orgId,
|
||||
actorType: logEntry.actorType,
|
||||
actor: logEntry.actor,
|
||||
actorId: logEntry.actorId,
|
||||
metadata: logEntry.metadata,
|
||||
orgId: sanitizeString(logEntry.orgId),
|
||||
actorType: sanitizeString(logEntry.actorType),
|
||||
actor: sanitizeString(logEntry.actor),
|
||||
actorId: sanitizeString(logEntry.actorId),
|
||||
metadata: sanitizeString(logEntry.metadata),
|
||||
action: logEntry.action,
|
||||
resourceId: logEntry.resourceId,
|
||||
reason: logEntry.reason,
|
||||
location: logEntry.location,
|
||||
location: sanitizeString(logEntry.location),
|
||||
// userAgent: data.userAgent, // TODO: add this
|
||||
// headers: data.body.headers,
|
||||
// query: data.body.query,
|
||||
originalRequestURL: logEntry.originalRequestURL,
|
||||
scheme: logEntry.scheme,
|
||||
host: logEntry.host,
|
||||
path: logEntry.path,
|
||||
method: logEntry.method,
|
||||
ip: logEntry.ip,
|
||||
originalRequestURL: sanitizeString(logEntry.originalRequestURL) ?? "",
|
||||
scheme: sanitizeString(logEntry.scheme) ?? "",
|
||||
host: sanitizeString(logEntry.host) ?? "",
|
||||
path: sanitizeString(logEntry.path) ?? "",
|
||||
method: sanitizeString(logEntry.method) ?? "",
|
||||
ip: sanitizeString(logEntry.ip),
|
||||
tls: logEntry.tls
|
||||
}));
|
||||
|
||||
@@ -1884,7 +1886,7 @@ hybridRouter.post(
|
||||
const batchSize = 100;
|
||||
for (let i = 0; i < logEntries.length; i += batchSize) {
|
||||
const batch = logEntries.slice(i, i + batchSize);
|
||||
await db.insert(requestAuditLog).values(batch);
|
||||
await logsDb.insert(requestAuditLog).values(batch);
|
||||
}
|
||||
|
||||
return response(res, {
|
||||
|
||||
@@ -52,7 +52,7 @@ registry.registerPath({
|
||||
method: "put",
|
||||
path: "/org/{orgId}/idp/oidc",
|
||||
description: "Create an OIDC IdP for a specific organization.",
|
||||
tags: [OpenAPITags.Idp, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.OrgIdp],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
|
||||
@@ -35,7 +35,7 @@ registry.registerPath({
|
||||
method: "delete",
|
||||
path: "/org/{orgId}/idp/{idpId}",
|
||||
description: "Delete IDP for a specific organization.",
|
||||
tags: [OpenAPITags.Idp, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.OrgIdp],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
|
||||
@@ -50,9 +50,9 @@ async function query(idpId: number, orgId: string) {
|
||||
|
||||
registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/:orgId/idp/:idpId",
|
||||
path: "/org/{orgId}/idp/{idpId}",
|
||||
description: "Get an IDP by its IDP ID for a specific organization.",
|
||||
tags: [OpenAPITags.Idp, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.OrgIdp],
|
||||
request: {
|
||||
params: paramsSchema
|
||||
},
|
||||
|
||||
@@ -67,7 +67,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/org/{orgId}/idp",
|
||||
description: "List all IDP for a specific organization.",
|
||||
tags: [OpenAPITags.Idp, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.OrgIdp],
|
||||
request: {
|
||||
query: querySchema,
|
||||
params: paramsSchema
|
||||
|
||||
@@ -59,7 +59,7 @@ registry.registerPath({
|
||||
method: "post",
|
||||
path: "/org/{orgId}/idp/{idpId}/oidc",
|
||||
description: "Update an OIDC IdP for a specific organization.",
|
||||
tags: [OpenAPITags.Idp, OpenAPITags.Org],
|
||||
tags: [OpenAPITags.OrgIdp],
|
||||
request: {
|
||||
params: paramsSchema,
|
||||
body: {
|
||||
|
||||
@@ -38,7 +38,7 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
);
|
||||
|
||||
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
|
||||
const newlyOfflineNodes = await db
|
||||
const offlineNodes = await db
|
||||
.update(exitNodes)
|
||||
.set({ online: false })
|
||||
.where(
|
||||
@@ -53,32 +53,15 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
|
||||
)
|
||||
.returning();
|
||||
|
||||
// Update the sites to offline if they have not pinged either
|
||||
const exitNodeIds = newlyOfflineNodes.map(
|
||||
(node) => node.exitNodeId
|
||||
);
|
||||
|
||||
const sitesOnNode = await db
|
||||
.select()
|
||||
.from(sites)
|
||||
.where(
|
||||
and(
|
||||
eq(sites.online, true),
|
||||
inArray(sites.exitNodeId, exitNodeIds)
|
||||
)
|
||||
if (offlineNodes.length > 0) {
|
||||
logger.info(
|
||||
`checkRemoteExitNodeOffline: Marked ${offlineNodes.length} remoteExitNode client(s) offline due to inactivity`
|
||||
);
|
||||
|
||||
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
|
||||
for (const site of sitesOnNode) {
|
||||
if (!site.lastBandwidthUpdate) {
|
||||
continue;
|
||||
}
|
||||
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
|
||||
if (Date.now() - lastBandwidthUpdate.getTime() > 60 * 1000) {
|
||||
await db
|
||||
.update(sites)
|
||||
.set({ online: false })
|
||||
.where(eq(sites.siteId, site.siteId));
|
||||
for (const offlineClient of offlineNodes) {
|
||||
logger.debug(
|
||||
`checkRemoteExitNodeOffline: Client ${offlineClient.exitNodeId} marked offline (lastPing: ${offlineClient.lastPing})`
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
@@ -52,7 +52,7 @@ registry.registerPath({
|
||||
method: "get",
|
||||
path: "/maintenance/info",
|
||||
description: "Get maintenance information for a resource by domain.",
|
||||
tags: [OpenAPITags.Resource],
|
||||
tags: [OpenAPITags.PublicResource],
|
||||
request: {
|
||||
query: z.object({
|
||||
fullDomain: z.string()
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
import { Request, Response, NextFunction } from "express";
|
||||
import { z } from "zod";
|
||||
import {
|
||||
actionAuditLog,
|
||||
db,
|
||||
logsDb,
|
||||
newts,
|
||||
roles,
|
||||
roundTripMessageTracker,
|
||||
@@ -29,12 +31,12 @@ import HttpCode from "@server/types/HttpCode";
|
||||
import createHttpError from "http-errors";
|
||||
import logger from "@server/logger";
|
||||
import { fromError } from "zod-validation-error";
|
||||
import { OpenAPITags, registry } from "@server/openApi";
|
||||
import { and, eq, inArray, or } from "drizzle-orm";
|
||||
import { canUserAccessSiteResource } from "@server/auth/canUserAccessSiteResource";
|
||||
import { signPublicKey, getOrgCAKeys } from "#private/lib/sshCA";
|
||||
import { signPublicKey, getOrgCAKeys } from "@server/lib/sshCA";
|
||||
import config from "@server/lib/config";
|
||||
import { sendToClient } from "#private/routers/ws";
|
||||
import { ActionsEnum } from "@server/auth/actions";
|
||||
|
||||
const paramsSchema = z.strictObject({
|
||||
orgId: z.string().nonempty()
|
||||
@@ -64,6 +66,7 @@ export type SignSshKeyResponse = {
|
||||
sshUsername: string;
|
||||
sshHost: string;
|
||||
resourceId: number;
|
||||
siteId: number;
|
||||
keyId: string;
|
||||
validPrincipals: string[];
|
||||
validAfter: string;
|
||||
@@ -185,7 +188,7 @@ export async function signSshKey(
|
||||
} else if (req.user?.username) {
|
||||
usernameToUse = req.user.username;
|
||||
// We need to clean out any spaces or special characters from the username to ensure it's valid for SSH certificates
|
||||
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "");
|
||||
usernameToUse = usernameToUse.replace(/[^a-zA-Z0-9_-]/g, "-");
|
||||
if (!usernameToUse) {
|
||||
return next(
|
||||
createHttpError(
|
||||
@@ -203,6 +206,9 @@ export async function signSshKey(
|
||||
);
|
||||
}
|
||||
|
||||
// prefix with p-
|
||||
usernameToUse = `p-${usernameToUse}`;
|
||||
|
||||
// check if we have a existing user in this org with the same
|
||||
const [existingUserWithSameName] = await db
|
||||
.select()
|
||||
@@ -248,6 +254,16 @@ export async function signSshKey(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
await db
|
||||
.update(userOrgs)
|
||||
.set({ pamUsername: usernameToUse })
|
||||
.where(
|
||||
and(
|
||||
eq(userOrgs.orgId, orgId),
|
||||
eq(userOrgs.userId, userId)
|
||||
)
|
||||
);
|
||||
} else {
|
||||
usernameToUse = userOrg.pamUsername;
|
||||
}
|
||||
@@ -319,7 +335,16 @@ export async function signSshKey(
|
||||
);
|
||||
}
|
||||
|
||||
// Check if the user has access to the resource (any of their roles)
|
||||
if (resource.mode == "cidr") {
|
||||
return next(
|
||||
createHttpError(
|
||||
HttpCode.BAD_REQUEST,
|
||||
"SSHing is not supported for CIDR resources"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Check if the user has access to the resource
|
||||
const hasAccess = await canUserAccessSiteResource({
|
||||
userId: userId,
|
||||
resourceId: resource.siteResourceId,
|
||||
@@ -444,6 +469,20 @@ export async function signSshKey(
|
||||
sshHost = resource.destination;
|
||||
}
|
||||
|
||||
await logsDb.insert(actionAuditLog).values({
|
||||
timestamp: Math.floor(Date.now() / 1000),
|
||||
orgId: orgId,
|
||||
actorType: "user",
|
||||
actor: req.user?.username ?? "",
|
||||
actorId: req.user?.userId ?? "",
|
||||
action: ActionsEnum.signSshKey,
|
||||
metadata: JSON.stringify({
|
||||
resourceId: resource.siteResourceId,
|
||||
resource: resource.name,
|
||||
siteId: resource.siteId,
|
||||
})
|
||||
});
|
||||
|
||||
return response<SignSshKeyResponse>(res, {
|
||||
data: {
|
||||
certificate: cert.certificate,
|
||||
@@ -451,6 +490,7 @@ export async function signSshKey(
|
||||
sshUsername: usernameToUse,
|
||||
sshHost: sshHost,
|
||||
resourceId: resource.siteResourceId,
|
||||
siteId: resource.siteId,
|
||||
keyId: cert.keyId,
|
||||
validPrincipals: cert.validPrincipals,
|
||||
validAfter: cert.validAfter.toISOString(),
|
||||
|
||||
@@ -17,10 +17,13 @@ import {
|
||||
startRemoteExitNodeOfflineChecker
|
||||
} from "#private/routers/remoteExitNode";
|
||||
import { MessageHandler } from "@server/routers/ws";
|
||||
import { build } from "@server/build";
|
||||
|
||||
export const messageHandlers: Record<string, MessageHandler> = {
|
||||
"remoteExitNode/register": handleRemoteExitNodeRegisterMessage,
|
||||
"remoteExitNode/ping": handleRemoteExitNodePingMessage
|
||||
};
|
||||
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
if (build != "saas") {
|
||||
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
*/
|
||||
|
||||
import { Router, Request, Response } from "express";
|
||||
import zlib from "zlib";
|
||||
import { Server as HttpServer } from "http";
|
||||
import { WebSocket, WebSocketServer } from "ws";
|
||||
import { Socket } from "net";
|
||||
@@ -24,7 +25,8 @@ import {
|
||||
OlmSession,
|
||||
RemoteExitNode,
|
||||
RemoteExitNodeSession,
|
||||
remoteExitNodes
|
||||
remoteExitNodes,
|
||||
sites
|
||||
} from "@server/db";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { db } from "@server/db";
|
||||
@@ -57,11 +59,13 @@ const MAX_PENDING_MESSAGES = 50; // Maximum messages to queue during connection
|
||||
const processMessage = async (
|
||||
ws: AuthenticatedWebSocket,
|
||||
data: Buffer,
|
||||
isBinary: boolean,
|
||||
clientId: string,
|
||||
clientType: ClientType
|
||||
): Promise<void> => {
|
||||
try {
|
||||
const message: WSMessage = JSON.parse(data.toString());
|
||||
const messageBuffer = isBinary ? zlib.gunzipSync(data) : data;
|
||||
const message: WSMessage = JSON.parse(messageBuffer.toString());
|
||||
|
||||
// logger.debug(
|
||||
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
|
||||
@@ -76,7 +80,7 @@ const processMessage = async (
|
||||
clientId,
|
||||
message.type, // Pass message type for granular limiting
|
||||
100, // max requests per window
|
||||
20, // max requests per message type per window
|
||||
100, // max requests per message type per window
|
||||
60 * 1000 // window in milliseconds
|
||||
);
|
||||
if (rateLimitResult.isLimited) {
|
||||
@@ -163,8 +167,16 @@ const processPendingMessages = async (
|
||||
);
|
||||
|
||||
const jobs = [];
|
||||
for (const messageData of ws.pendingMessages) {
|
||||
jobs.push(processMessage(ws, messageData, clientId, clientType));
|
||||
for (const pending of ws.pendingMessages) {
|
||||
jobs.push(
|
||||
processMessage(
|
||||
ws,
|
||||
pending.data,
|
||||
pending.isBinary,
|
||||
clientId,
|
||||
clientType
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
await Promise.all(jobs);
|
||||
@@ -185,6 +197,12 @@ const connectedClients: Map<string, AuthenticatedWebSocket[]> = new Map();
|
||||
// Config version tracking map (local to this node, resets on server restart)
|
||||
const clientConfigVersions: Map<string, number> = new Map();
|
||||
|
||||
// Tracks the last Unix timestamp (seconds) at which a ping was flushed to the
|
||||
// DB for a given siteId. Resets on server restart which is fine – the first
|
||||
// ping after startup will always write, re-establishing the online state.
|
||||
const lastPingDbWrite: Map<number, number> = new Map();
|
||||
const PING_DB_WRITE_INTERVAL = 45; // seconds
|
||||
|
||||
// Recovery tracking
|
||||
let isRedisRecoveryInProgress = false;
|
||||
|
||||
@@ -325,7 +343,9 @@ const addClient = async (
|
||||
// Check Redis first if enabled
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
const redisVersion = await redisManager.get(getConfigVersionKey(clientId));
|
||||
const redisVersion = await redisManager.get(
|
||||
getConfigVersionKey(clientId)
|
||||
);
|
||||
if (redisVersion !== null) {
|
||||
configVersion = parseInt(redisVersion, 10);
|
||||
// Sync to local cache
|
||||
@@ -337,7 +357,10 @@ const addClient = async (
|
||||
} else {
|
||||
// Use local cache version and sync to Redis
|
||||
configVersion = clientConfigVersions.get(clientId) || 0;
|
||||
await redisManager.set(getConfigVersionKey(clientId), configVersion.toString());
|
||||
await redisManager.set(
|
||||
getConfigVersionKey(clientId),
|
||||
configVersion.toString()
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error("Failed to get/set config version in Redis:", error);
|
||||
@@ -432,7 +455,9 @@ const removeClient = async (
|
||||
};
|
||||
|
||||
// Helper to get the current config version for a client
|
||||
const getClientConfigVersion = async (clientId: string): Promise<number | undefined> => {
|
||||
const getClientConfigVersion = async (
|
||||
clientId: string
|
||||
): Promise<number | undefined> => {
|
||||
// Try Redis first if available
|
||||
if (redisManager.isRedisEnabled()) {
|
||||
try {
|
||||
@@ -502,11 +527,26 @@ const sendToClientLocal = async (
|
||||
};
|
||||
|
||||
const messageString = JSON.stringify(messageWithVersion);
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(messageString);
|
||||
}
|
||||
});
|
||||
if (options.compress) {
|
||||
logger.debug(
|
||||
`Message size before compression: ${messageString.length} bytes`
|
||||
);
|
||||
const compressed = zlib.gzipSync(Buffer.from(messageString, "utf8"));
|
||||
logger.debug(
|
||||
`Message size after compression: ${compressed.length} bytes`
|
||||
);
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(compressed);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(messageString);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
@@ -532,11 +572,22 @@ const broadcastToAllExceptLocal = async (
|
||||
configVersion
|
||||
};
|
||||
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(JSON.stringify(messageWithVersion));
|
||||
}
|
||||
});
|
||||
if (options.compress) {
|
||||
const compressed = zlib.gzipSync(
|
||||
Buffer.from(JSON.stringify(messageWithVersion), "utf8")
|
||||
);
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(compressed);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
clients.forEach((client) => {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(JSON.stringify(messageWithVersion));
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -762,7 +813,7 @@ const setupConnection = async (
|
||||
}
|
||||
|
||||
// Set up message handler FIRST to prevent race condition
|
||||
ws.on("message", async (data) => {
|
||||
ws.on("message", async (data, isBinary) => {
|
||||
if (!ws.isFullyConnected) {
|
||||
// Queue message for later processing with limits
|
||||
ws.pendingMessages = ws.pendingMessages || [];
|
||||
@@ -777,11 +828,17 @@ const setupConnection = async (
|
||||
logger.debug(
|
||||
`Queueing message from ${clientType.toUpperCase()} ID: ${clientId} (connection not fully established)`
|
||||
);
|
||||
ws.pendingMessages.push(data as Buffer);
|
||||
ws.pendingMessages.push({ data: data as Buffer, isBinary });
|
||||
return;
|
||||
}
|
||||
|
||||
await processMessage(ws, data as Buffer, clientId, clientType);
|
||||
await processMessage(
|
||||
ws,
|
||||
data as Buffer,
|
||||
isBinary,
|
||||
clientId,
|
||||
clientType
|
||||
);
|
||||
});
|
||||
|
||||
// Set up other event handlers before async operations
|
||||
@@ -796,6 +853,35 @@ const setupConnection = async (
|
||||
);
|
||||
});
|
||||
|
||||
// Handle WebSocket protocol-level pings from older newt clients that do
|
||||
// not send application-level "newt/ping" messages. Update the site's
|
||||
// online state and lastPing timestamp so the offline checker treats them
|
||||
// the same as modern newt clients.
|
||||
if (clientType === "newt") {
|
||||
const newtClient = client as Newt;
|
||||
ws.on("ping", async () => {
|
||||
if (!newtClient.siteId) return;
|
||||
const now = Math.floor(Date.now() / 1000);
|
||||
const lastWrite = lastPingDbWrite.get(newtClient.siteId) ?? 0;
|
||||
if (now - lastWrite < PING_DB_WRITE_INTERVAL) return;
|
||||
lastPingDbWrite.set(newtClient.siteId, now);
|
||||
try {
|
||||
await db
|
||||
.update(sites)
|
||||
.set({
|
||||
online: true,
|
||||
lastPing: now
|
||||
})
|
||||
.where(eq(sites.siteId, newtClient.siteId));
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
"Error updating newt site online state on WS ping",
|
||||
{ error }
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
ws.on("error", (error: Error) => {
|
||||
logger.error(
|
||||
`WebSocket error for ${clientType.toUpperCase()} ID ${clientId}:`,
|
||||
|
||||
Reference in New Issue
Block a user