Merge branch 'dev' into feat/login-page-customization

This commit is contained in:
miloschwartz
2025-12-17 11:41:17 -05:00
660 changed files with 19695 additions and 12803 deletions

View File

@@ -19,8 +19,14 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryAccessAuditLogsParams, queryAccessAuditLogsQuery, queryAccess } from "./queryAccessAuditLog";
import {
queryAccessAuditLogsParams,
queryAccessAuditLogsQuery,
queryAccess,
countAccessQuery
} from "./queryAccessAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs";
registry.registerPath({
method: "get",
@@ -61,16 +67,28 @@ export async function exportAccessAuditLogs(
}
const data = { ...parsedQuery.data, ...parsedParams.data };
const [{ count }] = await countAccessQuery(data);
if (count > MAX_EXPORT_LIMIT) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.`
)
);
}
const baseQuery = queryAccess(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="access-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +96,4 @@ export async function exportAccessAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -19,8 +19,14 @@ import createHttpError from "http-errors";
import HttpCode from "@server/types/HttpCode";
import { fromError } from "zod-validation-error";
import logger from "@server/logger";
import { queryActionAuditLogsParams, queryActionAuditLogsQuery, queryAction } from "./queryActionAuditLog";
import {
queryActionAuditLogsParams,
queryActionAuditLogsQuery,
queryAction,
countActionQuery
} from "./queryActionAuditLog";
import { generateCSV } from "@server/routers/auditLogs/generateCSV";
import { MAX_EXPORT_LIMIT } from "@server/routers/auditLogs";
registry.registerPath({
method: "get",
@@ -60,17 +66,29 @@ export async function exportActionAuditLogs(
);
}
const data = { ...parsedQuery.data, ...parsedParams.data };
const data = { ...parsedQuery.data, ...parsedParams.data };
const [{ count }] = await countActionQuery(data);
if (count > MAX_EXPORT_LIMIT) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
`Export limit exceeded. Your selection contains ${count} rows, but the maximum is ${MAX_EXPORT_LIMIT} rows. Please select a shorter time range to reduce the data.`
)
);
}
const baseQuery = queryAction(data);
const log = await baseQuery.limit(data.limit).offset(data.offset);
const csvData = generateCSV(log);
res.setHeader('Content-Type', 'text/csv');
res.setHeader('Content-Disposition', `attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`);
res.setHeader("Content-Type", "text/csv");
res.setHeader(
"Content-Disposition",
`attachment; filename="action-audit-logs-${data.orgId}-${Date.now()}.csv"`
);
return res.send(csvData);
} catch (error) {
logger.error(error);
@@ -78,4 +96,4 @@ export async function exportActionAuditLogs(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An error occurred")
);
}
}
}

View File

@@ -14,4 +14,4 @@
export * from "./queryActionAuditLog";
export * from "./exportActionAuditLog";
export * from "./queryAccessAuditLog";
export * from "./exportAccessAuditLog";
export * from "./exportAccessAuditLog";

View File

@@ -24,6 +24,7 @@ import { fromError } from "zod-validation-error";
import { QueryAccessAuditLogResponse } from "@server/routers/auditLogs/types";
import response from "@server/lib/response";
import logger from "@server/logger";
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
export const queryAccessAuditLogsQuery = z.object({
// iso string just validate its a parseable date
@@ -32,7 +33,14 @@ export const queryAccessAuditLogsQuery = z.object({
.refine((val) => !isNaN(Date.parse(val)), {
error: "timeStart must be a valid ISO date string"
})
.transform((val) => Math.floor(new Date(val).getTime() / 1000)),
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
.prefault(() => getSevenDaysAgo().toISOString())
.openapi({
type: "string",
format: "date-time",
description:
"Start time as ISO date string (defaults to 7 days ago)"
}),
timeEnd: z
.string()
.refine((val) => !isNaN(Date.parse(val)), {
@@ -44,7 +52,8 @@ export const queryAccessAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z
.union([z.boolean(), z.string()])
@@ -181,9 +190,15 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter((row): row is { id: number; name: string | null } => row.id !== null),
locations: uniqueLocations.map(row => row.locations).filter((location): location is string => location !== null)
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
resources: uniqueResources.filter(
(row): row is { id: number; name: string | null } => row.id !== null
),
locations: uniqueLocations
.map((row) => row.locations)
.filter((location): location is string => location !== null)
};
}

View File

@@ -24,6 +24,7 @@ import { fromError } from "zod-validation-error";
import { QueryActionAuditLogResponse } from "@server/routers/auditLogs/types";
import response from "@server/lib/response";
import logger from "@server/logger";
import { getSevenDaysAgo } from "@app/lib/getSevenDaysAgo";
export const queryActionAuditLogsQuery = z.object({
// iso string just validate its a parseable date
@@ -32,7 +33,14 @@ export const queryActionAuditLogsQuery = z.object({
.refine((val) => !isNaN(Date.parse(val)), {
error: "timeStart must be a valid ISO date string"
})
.transform((val) => Math.floor(new Date(val).getTime() / 1000)),
.transform((val) => Math.floor(new Date(val).getTime() / 1000))
.prefault(() => getSevenDaysAgo().toISOString())
.openapi({
type: "string",
format: "date-time",
description:
"Start time as ISO date string (defaults to 7 days ago)"
}),
timeEnd: z
.string()
.refine((val) => !isNaN(Date.parse(val)), {
@@ -44,7 +52,8 @@ export const queryActionAuditLogsQuery = z.object({
.openapi({
type: "string",
format: "date-time",
description: "End time as ISO date string (defaults to current time)"
description:
"End time as ISO date string (defaults to current time)"
}),
action: z.string().optional(),
actorType: z.string().optional(),
@@ -68,8 +77,9 @@ export const queryActionAuditLogsParams = z.object({
orgId: z.string()
});
export const queryActionAuditLogsCombined =
queryActionAuditLogsQuery.merge(queryActionAuditLogsParams);
export const queryActionAuditLogsCombined = queryActionAuditLogsQuery.merge(
queryActionAuditLogsParams
);
type Q = z.infer<typeof queryActionAuditLogsCombined>;
function getWhere(data: Q) {
@@ -78,7 +88,9 @@ function getWhere(data: Q) {
lt(actionAuditLog.timestamp, data.timeEnd),
eq(actionAuditLog.orgId, data.orgId),
data.actor ? eq(actionAuditLog.actor, data.actor) : undefined,
data.actorType ? eq(actionAuditLog.actorType, data.actorType) : undefined,
data.actorType
? eq(actionAuditLog.actorType, data.actorType)
: undefined,
data.actorId ? eq(actionAuditLog.actorId, data.actorId) : undefined,
data.action ? eq(actionAuditLog.action, data.action) : undefined
);
@@ -135,8 +147,12 @@ async function queryUniqueFilterAttributes(
.where(baseConditions);
return {
actors: uniqueActors.map(row => row.actor).filter((actor): actor is string => actor !== null),
actions: uniqueActions.map(row => row.action).filter((action): action is string => action !== null),
actors: uniqueActors
.map((row) => row.actor)
.filter((actor): actor is string => actor !== null),
actions: uniqueActions
.map((row) => row.action)
.filter((action): action is string => action !== null)
};
}

View File

@@ -13,4 +13,4 @@
export * from "./transferSession";
export * from "./getSessionTransferToken";
export * from "./quickStart";
export * from "./quickStart";

View File

@@ -395,7 +395,8 @@ export async function quickStart(
.values({
targetId: newTarget[0].targetId,
hcEnabled: false
}).returning();
})
.returning();
// add the new target to the targetIps array
targetIps.push(`${ip}/32`);
@@ -406,7 +407,12 @@ export async function quickStart(
.where(eq(newts.siteId, siteId!))
.limit(1);
await addTargets(newt.newtId, newTarget, newHealthcheck, resource.protocol);
await addTargets(
newt.newtId,
newTarget,
newHealthcheck,
resource.protocol
);
// Set resource pincode if provided
if (pincode) {

View File

@@ -26,8 +26,8 @@ import { getLineItems, getStandardFeaturePriceSet } from "@server/lib/billing";
import { getTierPriceSet, TierId } from "@server/lib/billing/tiers";
const createCheckoutSessionSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function createCheckoutSession(
req: Request,
@@ -72,7 +72,7 @@ export async function createCheckoutSession(
billing_address_collection: "required",
line_items: [
{
price: standardTierPrice, // Use the standard tier
price: standardTierPrice, // Use the standard tier
quantity: 1
},
...getLineItems(getStandardFeaturePriceSet())

View File

@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
import stripe from "#private/lib/stripe";
const createPortalSessionSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function createPortalSession(
req: Request,

View File

@@ -34,8 +34,8 @@ import {
} from "@server/db";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
registry.registerPath({
method: "get",

View File

@@ -28,8 +28,8 @@ import { FeatureId } from "@server/lib/billing";
import { GetOrgUsageResponse } from "@server/routers/billing/types";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
registry.registerPath({
method: "get",
@@ -78,11 +78,23 @@ export async function getOrgUsage(
// Get usage for org
const usageData = [];
const siteUptime = await usageService.getUsage(orgId, FeatureId.SITE_UPTIME);
const siteUptime = await usageService.getUsage(
orgId,
FeatureId.SITE_UPTIME
);
const users = await usageService.getUsageDaily(orgId, FeatureId.USERS);
const domains = await usageService.getUsageDaily(orgId, FeatureId.DOMAINS);
const remoteExitNodes = await usageService.getUsageDaily(orgId, FeatureId.REMOTE_EXIT_NODES);
const egressData = await usageService.getUsage(orgId, FeatureId.EGRESS_DATA_MB);
const domains = await usageService.getUsageDaily(
orgId,
FeatureId.DOMAINS
);
const remoteExitNodes = await usageService.getUsageDaily(
orgId,
FeatureId.REMOTE_EXIT_NODES
);
const egressData = await usageService.getUsage(
orgId,
FeatureId.EGRESS_DATA_MB
);
if (siteUptime) {
usageData.push(siteUptime);
@@ -100,7 +112,8 @@ export async function getOrgUsage(
usageData.push(remoteExitNodes);
}
const orgLimits = await db.select()
const orgLimits = await db
.select()
.from(limits)
.where(eq(limits.orgId, orgId));

View File

@@ -31,9 +31,7 @@ export async function handleCustomerDeleted(
return;
}
await db
.delete(customers)
.where(eq(customers.customerId, customer.id));
await db.delete(customers).where(eq(customers.customerId, customer.id));
} catch (error) {
logger.error(
`Error handling customer created event for ID ${customer.id}:`,

View File

@@ -12,7 +12,14 @@
*/
import Stripe from "stripe";
import { subscriptions, db, subscriptionItems, customers, userOrgs, users } from "@server/db";
import {
subscriptions,
db,
subscriptionItems,
customers,
userOrgs,
users
} from "@server/db";
import { eq, and } from "drizzle-orm";
import logger from "@server/logger";
import { handleSubscriptionLifesycle } from "../subscriptionLifecycle";
@@ -43,7 +50,6 @@ export async function handleSubscriptionDeleted(
.delete(subscriptionItems)
.where(eq(subscriptionItems.subscriptionId, subscription.id));
// Lookup customer to get orgId
const [customer] = await db
.select()
@@ -58,10 +64,7 @@ export async function handleSubscriptionDeleted(
return;
}
await handleSubscriptionLifesycle(
customer.orgId,
subscription.status
);
await handleSubscriptionLifesycle(customer.orgId, subscription.status);
const [orgUserRes] = await db
.select()

View File

@@ -15,4 +15,4 @@ export * from "./createCheckoutSession";
export * from "./createPortalSession";
export * from "./getOrgSubscription";
export * from "./getOrgUsage";
export * from "./internalGetOrgTier";
export * from "./internalGetOrgTier";

View File

@@ -22,8 +22,8 @@ import { getOrgTierData } from "#private/lib/billing";
import { GetOrgTierResponse } from "@server/routers/billing/types";
const getOrgSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function getOrgTier(
req: Request,

View File

@@ -11,11 +11,18 @@
* This file is not licensed under the AGPLv3.
*/
import { freeLimitSet, limitsService, subscribedLimitSet } from "@server/lib/billing";
import {
freeLimitSet,
limitsService,
subscribedLimitSet
} from "@server/lib/billing";
import { usageService } from "@server/lib/billing/usageService";
import logger from "@server/logger";
export async function handleSubscriptionLifesycle(orgId: string, status: string) {
export async function handleSubscriptionLifesycle(
orgId: string,
status: string
) {
switch (status) {
case "active":
await limitsService.applyLimitSetToOrg(orgId, subscribedLimitSet);
@@ -42,4 +49,4 @@ export async function handleSubscriptionLifesycle(orgId: string, status: string)
default:
break;
}
}
}

View File

@@ -32,12 +32,13 @@ export async function billingWebhookHandler(
next: NextFunction
): Promise<any> {
let event: Stripe.Event = req.body;
const endpointSecret = privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
const endpointSecret =
privateConfig.getRawPrivateConfig().stripe?.webhook_secret;
if (!endpointSecret) {
logger.warn("Stripe webhook secret is not configured. Webhook events will not be priocessed.");
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "")
logger.warn(
"Stripe webhook secret is not configured. Webhook events will not be priocessed."
);
return next(createHttpError(HttpCode.INTERNAL_SERVER_ERROR, ""));
}
// Only verify the event if you have an endpoint secret defined.
@@ -49,7 +50,10 @@ export async function billingWebhookHandler(
if (!signature) {
logger.info("No stripe signature found in headers.");
return next(
createHttpError(HttpCode.BAD_REQUEST, "No stripe signature found in headers")
createHttpError(
HttpCode.BAD_REQUEST,
"No stripe signature found in headers"
)
);
}
@@ -62,7 +66,10 @@ export async function billingWebhookHandler(
} catch (err) {
logger.error(`Webhook signature verification failed.`, err);
return next(
createHttpError(HttpCode.UNAUTHORIZED, "Webhook signature verification failed")
createHttpError(
HttpCode.UNAUTHORIZED,
"Webhook signature verification failed"
)
);
}
}

View File

@@ -24,10 +24,10 @@ import { registry } from "@server/openApi";
import { GetCertificateResponse } from "@server/routers/certificates/types";
const getCertificateSchema = z.strictObject({
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
});
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
});
async function query(domainId: string, domain: string) {
const [domainRecord] = await db
@@ -42,8 +42,8 @@ async function query(domainId: string, domain: string) {
let existing: any[] = [];
if (domainRecord.type == "ns") {
const domainLevelDown = domain.split('.').slice(1).join('.');
const domainLevelDown = domain.split(".").slice(1).join(".");
existing = await db
.select({
certId: certificates.certId,
@@ -64,7 +64,7 @@ async function query(domainId: string, domain: string) {
eq(certificates.wildcard, true), // only NS domains can have wildcard certs
or(
eq(certificates.domain, domain),
eq(certificates.domain, domainLevelDown),
eq(certificates.domain, domainLevelDown)
)
)
);
@@ -102,8 +102,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
domainId: z
.string(),
domainId: z.string(),
domain: z.string().min(1).max(255),
orgId: z.string()
})
@@ -133,7 +132,9 @@ export async function getCertificate(
if (!cert) {
logger.warn(`Certificate not found for domain: ${domainId}`);
return next(createHttpError(HttpCode.NOT_FOUND, "Certificate not found"));
return next(
createHttpError(HttpCode.NOT_FOUND, "Certificate not found")
);
}
return response<GetCertificateResponse>(res, {

View File

@@ -12,4 +12,4 @@
*/
export * from "./getCertificate";
export * from "./restartCertificate";
export * from "./restartCertificate";

View File

@@ -25,9 +25,9 @@ import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
const restartCertificateParamsSchema = z.strictObject({
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
});
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
});
registry.registerPath({
method: "post",
@@ -36,10 +36,7 @@ registry.registerPath({
tags: ["Certificate"],
request: {
params: z.object({
certId: z
.string()
.transform(stoi)
.pipe(z.int().positive()),
certId: z.string().transform(stoi).pipe(z.int().positive()),
orgId: z.string()
})
},
@@ -94,7 +91,7 @@ export async function restartCertificate(
.set({
status: "pending",
errorMessage: null,
lastRenewalAttempt: Math.floor(Date.now() / 1000)
lastRenewalAttempt: Math.floor(Date.now() / 1000)
})
.where(eq(certificates.certId, certId));

View File

@@ -26,8 +26,8 @@ import { CheckDomainAvailabilityResponse } from "@server/routers/domain/types";
const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({
subdomain: z.string()
});
subdomain: z.string()
});
registry.registerPath({
method: "get",

View File

@@ -12,4 +12,4 @@
*/
export * from "./checkDomainNamespaceAvailability";
export * from "./listDomainNamespaces";
export * from "./listDomainNamespaces";

View File

@@ -26,19 +26,19 @@ import { OpenAPITags, registry } from "@server/openApi";
const paramsSchema = z.strictObject({});
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
async function query(limit: number, offset: number) {
const res = await db

View File

@@ -30,7 +30,7 @@ import {
verifyUserHasAction,
verifyUserIsServerAdmin,
verifySiteAccess,
verifyClientAccess,
verifyClientAccess
} from "@server/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import {
@@ -436,6 +436,8 @@ authenticated.get(
authenticated.post(
"/re-key/:clientId/regenerate-client-secret",
verifyValidLicense,
verifyValidSubscription,
verifyClientAccess,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateClientSecret
@@ -443,15 +445,18 @@ authenticated.post(
authenticated.post(
"/re-key/:siteId/regenerate-site-secret",
verifyValidLicense,
verifyValidSubscription,
verifySiteAccess,
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateSiteSecret
);
authenticated.put(
"/re-key/:orgId/reGenerate-remote-exit-node-secret",
"/re-key/:orgId/regenerate-remote-exit-node-secret",
verifyValidLicense,
verifyValidSubscription,
verifyOrgAccess,
verifyUserHasAction(ActionsEnum.updateRemoteExitNode),
verifyUserHasAction(ActionsEnum.reGenerateSecret),
reKey.reGenerateExitNodeSecret
);

View File

@@ -1,13 +0,0 @@
/*
* This file is part of a proprietary work.
*
* Copyright (c) 2025 Fossorial, Inc.
* All rights reserved.
*
* This file is licensed under the Fossorial Commercial License.
* You may not use this file except in compliance with the License.
* Unauthorized use, copying, modification, or distribution is strictly prohibited.
*
* This file is not licensed under the AGPLv3.
*/

View File

@@ -79,86 +79,72 @@ import semver from "semver";
// Zod schemas for request validation
const getResourceByDomainParamsSchema = z.strictObject({
domain: z.string().min(1, "Domain is required")
});
domain: z.string().min(1, "Domain is required")
});
const getUserSessionParamsSchema = z.strictObject({
userSessionId: z.string().min(1, "User session ID is required")
});
userSessionId: z.string().min(1, "User session ID is required")
});
const getUserOrgRoleParamsSchema = z.strictObject({
userId: z.string().min(1, "User ID is required"),
orgId: z.string().min(1, "Organization ID is required")
});
userId: z.string().min(1, "User ID is required"),
orgId: z.string().min(1, "Organization ID is required")
});
const getRoleResourceAccessParamsSchema = z.strictObject({
roleId: z
.string()
.transform(Number)
.pipe(
z.int().positive("Role ID must be a positive integer")
),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
roleId: z
.string()
.transform(Number)
.pipe(z.int().positive("Role ID must be a positive integer")),
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getUserResourceAccessParamsSchema = z.strictObject({
userId: z.string().min(1, "User ID is required"),
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
userId: z.string().min(1, "User ID is required"),
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const getResourceRulesParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenParamsSchema = z.strictObject({
resourceId: z
.string()
.transform(Number)
.pipe(
z.int()
.positive("Resource ID must be a positive integer")
)
});
resourceId: z
.string()
.transform(Number)
.pipe(z.int().positive("Resource ID must be a positive integer"))
});
const validateResourceSessionTokenBodySchema = z.strictObject({
token: z.string().min(1, "Token is required")
});
token: z.string().min(1, "Token is required")
});
const validateResourceAccessTokenBodySchema = z.strictObject({
accessTokenId: z.string().optional(),
resourceId: z.number().optional(),
accessToken: z.string()
});
accessTokenId: z.string().optional(),
resourceId: z.number().optional(),
accessToken: z.string()
});
// Certificates by domains query validation
const getCertificatesByDomainsQuerySchema = z.strictObject({
// Accept domains as string or array (domains or domains[])
domains: z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional(),
// Handle array format from query parameters (domains[])
"domains[]": z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional()
});
// Accept domains as string or array (domains or domains[])
domains: z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional(),
// Handle array format from query parameters (domains[])
"domains[]": z
.union([z.array(z.string().min(1)), z.string().min(1)])
.optional()
});
// Type exports for request schemas
export type GetResourceByDomainParams = z.infer<
@@ -566,8 +552,8 @@ hybridRouter.get(
);
const getOrgLoginPageParamsSchema = z.strictObject({
orgId: z.string().min(1)
});
orgId: z.string().min(1)
});
hybridRouter.get(
"/org/:orgId/login-page",
@@ -1408,8 +1394,16 @@ hybridRouter.post(
);
}
const { olmId, newtId, ip, port, timestamp, token, publicKey, reachableAt } =
parsedParams.data;
const {
olmId,
newtId,
ip,
port,
timestamp,
token,
publicKey,
reachableAt
} = parsedParams.data;
const destinations = await updateAndGenerateEndpointDestinations(
olmId,

View File

@@ -18,7 +18,7 @@ import * as logs from "#private/routers/auditLogs";
import {
verifyApiKeyHasAction,
verifyApiKeyIsRoot,
verifyApiKeyOrgAccess,
verifyApiKeyOrgAccess
} from "@server/middlewares";
import {
verifyValidSubscription,
@@ -26,7 +26,10 @@ import {
} from "#private/middlewares";
import { ActionsEnum } from "@server/auth/actions";
import { unauthenticated as ua, authenticated as a } from "@server/routers/integration";
import {
unauthenticated as ua,
authenticated as a
} from "@server/routers/integration";
import { logActionAudit } from "#private/middlewares";
export const unauthenticated = ua;
@@ -37,7 +40,7 @@ authenticated.post(
verifyApiKeyIsRoot, // We are the only ones who can use root key so its fine
verifyApiKeyHasAction(ActionsEnum.sendUsageNotification),
logActionAudit(ActionsEnum.sendUsageNotification),
org.sendUsageNotification,
org.sendUsageNotification
);
authenticated.delete(
@@ -45,7 +48,7 @@ authenticated.delete(
verifyApiKeyIsRoot,
verifyApiKeyHasAction(ActionsEnum.deleteIdp),
logActionAudit(ActionsEnum.deleteIdp),
orgIdp.deleteOrgIdp,
orgIdp.deleteOrgIdp
);
authenticated.get(

View File

@@ -21,8 +21,8 @@ import { z } from "zod";
import { fromError } from "zod-validation-error";
const bodySchema = z.strictObject({
licenseKey: z.string().min(1).max(255)
});
licenseKey: z.string().min(1).max(255)
});
export async function activateLicense(
req: Request,

View File

@@ -24,8 +24,8 @@ import { licenseKey } from "@server/db";
import license from "#private/license/license";
const paramsSchema = z.strictObject({
licenseKey: z.string().min(1).max(255)
});
licenseKey: z.string().min(1).max(255)
});
export async function deleteLicenseKey(
req: Request,

View File

@@ -36,13 +36,13 @@ import { build } from "@server/build";
import { CreateLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const bodySchema = z.strictObject({
subdomain: z.string().nullable().optional(),
domainId: z.string()
});
subdomain: z.string().nullable().optional(),
domainId: z.string()
});
export type CreateLoginPageBody = z.infer<typeof bodySchema>;
@@ -149,12 +149,20 @@ export async function createLoginPage(
let returned: LoginPage | undefined;
await db.transaction(async (trx) => {
const orgSites = await trx
.select()
.from(sites)
.innerJoin(exitNodes, eq(exitNodes.exitNodeId, sites.exitNodeId))
.where(and(eq(sites.orgId, orgId), eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.innerJoin(
exitNodes,
eq(exitNodes.exitNodeId, sites.exitNodeId)
)
.where(
and(
eq(sites.orgId, orgId),
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
let exitNodesList = orgSites.map((s) => s.exitNodes);
@@ -163,7 +171,12 @@ export async function createLoginPage(
exitNodesList = await trx
.select()
.from(exitNodes)
.where(and(eq(exitNodes.type, "gerbil"), eq(exitNodes.online, true)))
.where(
and(
eq(exitNodes.type, "gerbil"),
eq(exitNodes.online, true)
)
)
.limit(10);
}

View File

@@ -78,15 +78,11 @@ export async function deleteLoginPage(
// if (!leftoverLinks.length) {
await db
.delete(loginPage)
.where(
eq(loginPage.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPage.loginPageId, parsedParams.data.loginPageId));
await db
.delete(loginPageOrg)
.where(
eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId)
);
.where(eq(loginPageOrg.loginPageId, parsedParams.data.loginPageId));
// }
return response<LoginPage>(res, {

View File

@@ -23,8 +23,8 @@ import { fromError } from "zod-validation-error";
import { GetLoginPageResponse } from "@server/routers/loginPage/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
async function query(orgId: string) {
const [res] = await db

View File

@@ -35,7 +35,8 @@ const paramsSchema = z
})
.strict();
const bodySchema = z.strictObject({
const bodySchema = z
.strictObject({
subdomain: subdomainSchema.nullable().optional(),
domainId: z.string().optional()
})
@@ -86,7 +87,7 @@ export async function updateLoginPage(
const { loginPageId, orgId } = parsedParams.data;
if (build === "saas"){
if (build === "saas") {
const { tier } = await getOrgTierData(orgId);
const subscribed = tier === TierId.STANDARD;
if (!subscribed) {
@@ -182,7 +183,10 @@ export async function updateLoginPage(
}
// update the full domain if it has changed
if (fullDomain && fullDomain !== existingLoginPage?.fullDomain) {
if (
fullDomain &&
fullDomain !== existingLoginPage?.fullDomain
) {
await db
.update(loginPage)
.set({ fullDomain })

View File

@@ -23,9 +23,9 @@ import SupportEmail from "@server/emails/templates/SupportEmail";
import config from "@server/lib/config";
const bodySchema = z.strictObject({
body: z.string().min(1),
subject: z.string().min(1).max(255)
});
body: z.string().min(1),
subject: z.string().min(1).max(255)
});
export async function sendSupportEmail(
req: Request,
@@ -66,6 +66,7 @@ export async function sendSupportEmail(
{
name: req.user?.email || "Support User",
to: "support@pangolin.net",
replyTo: req.user?.email || undefined,
from: config.getNoReplyEmail(),
subject: `Support Request: ${subject}`
}

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./sendUsageNotifications";
export * from "./sendUsageNotifications";

View File

@@ -35,10 +35,12 @@ const sendUsageNotificationBodySchema = z.object({
notificationType: z.enum(["approaching_70", "approaching_90", "reached"]),
limitName: z.string(),
currentUsage: z.number(),
usageLimit: z.number(),
usageLimit: z.number()
});
type SendUsageNotificationRequest = z.infer<typeof sendUsageNotificationBodySchema>;
type SendUsageNotificationRequest = z.infer<
typeof sendUsageNotificationBodySchema
>;
export type SendUsageNotificationResponse = {
success: boolean;
@@ -97,17 +99,13 @@ async function getOrgAdmins(orgId: string) {
.where(
and(
eq(userOrgs.orgId, orgId),
or(
eq(userOrgs.isOwner, true),
eq(roles.isAdmin, true)
)
or(eq(userOrgs.isOwner, true), eq(roles.isAdmin, true))
)
);
// Filter to only include users with verified emails
const orgAdmins = admins.filter(admin =>
admin.email &&
admin.email.length > 0
const orgAdmins = admins.filter(
(admin) => admin.email && admin.email.length > 0
);
return orgAdmins;
@@ -119,7 +117,9 @@ export async function sendUsageNotification(
next: NextFunction
): Promise<any> {
try {
const parsedParams = sendUsageNotificationParamsSchema.safeParse(req.params);
const parsedParams = sendUsageNotificationParamsSchema.safeParse(
req.params
);
if (!parsedParams.success) {
return next(
createHttpError(
@@ -140,12 +140,8 @@ export async function sendUsageNotification(
}
const { orgId } = parsedParams.data;
const {
notificationType,
limitName,
currentUsage,
usageLimit,
} = parsedBody.data;
const { notificationType, limitName, currentUsage, usageLimit } =
parsedBody.data;
// Verify organization exists
const org = await db
@@ -192,7 +188,10 @@ export async function sendUsageNotification(
let template;
let subject;
if (notificationType === "approaching_70" || notificationType === "approaching_90") {
if (
notificationType === "approaching_70" ||
notificationType === "approaching_90"
) {
template = NotifyUsageLimitApproaching({
email: admin.email,
limitName,
@@ -220,10 +219,15 @@ export async function sendUsageNotification(
emailsSent++;
adminEmails.push(admin.email);
logger.info(`Usage notification sent to admin ${admin.email} for org ${orgId}`);
logger.info(
`Usage notification sent to admin ${admin.email} for org ${orgId}`
);
} catch (emailError) {
logger.error(`Failed to send usage notification to ${admin.email}:`, emailError);
logger.error(
`Failed to send usage notification to ${admin.email}:`,
emailError
);
// Continue with other admins even if one fails
}
}
@@ -239,11 +243,13 @@ export async function sendUsageNotification(
message: `Usage notifications sent to ${emailsSent} administrators`,
status: HttpCode.OK
});
} catch (error) {
logger.error("Error sending usage notifications:", error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "Failed to send usage notifications")
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"Failed to send usage notifications"
)
);
}
}

View File

@@ -32,19 +32,19 @@ import { CreateOrgIdpResponse } from "@server/routers/orgIdp/types";
const paramsSchema = z.strictObject({ orgId: z.string().nonempty() });
const bodySchema = z.strictObject({
name: z.string().nonempty(),
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.url(),
tokenUrl: z.url(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
});
name: z.string().nonempty(),
clientId: z.string().nonempty(),
clientSecret: z.string().nonempty(),
authUrl: z.url(),
tokenUrl: z.url(),
identifierPath: z.string().nonempty(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().nonempty(),
autoProvision: z.boolean().optional(),
variant: z.enum(["oidc", "google", "azure"]).optional().default("oidc"),
roleMapping: z.string().optional()
});
// registry.registerPath({
// method: "put",
@@ -158,7 +158,10 @@ export async function createOrgOidcIdp(
});
});
const redirectUrl = await generateOidcRedirectUrl(idpId as number, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpId as number,
orgId
);
return response<CreateOrgIdpResponse>(res, {
data: {

View File

@@ -66,12 +66,7 @@ export async function deleteOrgIdp(
.where(eq(idp.idpId, idpId));
if (!existingIdp) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
"IdP not found"
)
);
return next(createHttpError(HttpCode.NOT_FOUND, "IdP not found"));
}
// Delete the IDP and its related records in a transaction
@@ -82,14 +77,10 @@ export async function deleteOrgIdp(
.where(eq(idpOidcConfig.idpId, idpId));
// Delete IDP-org mappings
await trx
.delete(idpOrg)
.where(eq(idpOrg.idpId, idpId));
await trx.delete(idpOrg).where(eq(idpOrg.idpId, idpId));
// Delete the IDP itself
await trx
.delete(idp)
.where(eq(idp.idpId, idpId));
await trx.delete(idp).where(eq(idp.idpId, idpId));
});
return response<null>(res, {

View File

@@ -93,7 +93,10 @@ export async function getOrgIdp(
idpRes.idpOidcConfig!.clientId = decrypt(clientId, key);
}
const redirectUrl = await generateOidcRedirectUrl(idpRes.idp.idpId, orgId);
const redirectUrl = await generateOidcRedirectUrl(
idpRes.idp.idpId,
orgId
);
return response<GetOrgIdpResponse>(res, {
data: {

View File

@@ -15,4 +15,4 @@ export * from "./createOrgOidcIdp";
export * from "./getOrgIdp";
export * from "./listOrgIdps";
export * from "./updateOrgOidcIdp";
export * from "./deleteOrgIdp";
export * from "./deleteOrgIdp";

View File

@@ -25,23 +25,23 @@ import { OpenAPITags, registry } from "@server/openApi";
import { ListOrgIdpsResponse } from "@server/routers/orgIdp/types";
const querySchema = z.strictObject({
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
limit: z
.string()
.optional()
.default("1000")
.transform(Number)
.pipe(z.int().nonnegative()),
offset: z
.string()
.optional()
.default("0")
.transform(Number)
.pipe(z.int().nonnegative())
});
const paramsSchema = z.strictObject({
orgId: z.string().nonempty()
});
orgId: z.string().nonempty()
});
async function query(orgId: string, limit: number, offset: number) {
const res = await db

View File

@@ -36,18 +36,18 @@ const paramsSchema = z
.strict();
const bodySchema = z.strictObject({
name: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional(),
authUrl: z.string().optional(),
tokenUrl: z.string().optional(),
identifierPath: z.string().optional(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
});
name: z.string().optional(),
clientId: z.string().optional(),
clientSecret: z.string().optional(),
authUrl: z.string().optional(),
tokenUrl: z.string().optional(),
identifierPath: z.string().optional(),
emailPath: z.string().optional(),
namePath: z.string().optional(),
scopes: z.string().optional(),
autoProvision: z.boolean().optional(),
roleMapping: z.string().optional()
});
export type UpdateOrgIdpResponse = {
idpId: number;

View File

@@ -13,4 +13,4 @@
export * from "./reGenerateClientSecret";
export * from "./reGenerateSiteSecret";
export * from "./reGenerateExitNodeSecret";
export * from "./reGenerateExitNodeSecret";

View File

@@ -13,7 +13,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, olms, } from "@server/db";
import { db, Olm, olms } from "@server/db";
import { clients } from "@server/db";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -23,37 +23,19 @@ import { eq, and } from "drizzle-orm";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { hashPassword } from "@server/auth/password";
import { disconnectClient, sendToClient } from "#private/routers/ws";
const reGenerateSecretParamsSchema = z.strictObject({
clientId: z.string().transform(Number).pipe(z.int().positive())
});
const reGenerateSecretBodySchema = z.strictObject({
olmId: z.string().min(1).optional(),
secret: z.string().min(1).optional(),
});
export type ReGenerateSecretBody = z.infer<typeof reGenerateSecretBodySchema>;
registry.registerPath({
method: "post",
path: "/re-key/{clientId}/regenerate-client-secret",
description: "Regenerate a client's OLM credentials by its client ID.",
tags: [OpenAPITags.Client],
request: {
params: reGenerateSecretParamsSchema,
body: {
content: {
"application/json": {
schema: reGenerateSecretBodySchema
}
}
}
},
responses: {}
clientId: z.string().transform(Number).pipe(z.int().positive())
});
const reGenerateSecretBodySchema = z.strictObject({
// olmId: z.string().min(1).optional(),
secret: z.string().min(1),
disconnect: z.boolean().optional().default(true)
});
export type ReGenerateSecretBody = z.infer<typeof reGenerateSecretBodySchema>;
export async function reGenerateClientSecret(
req: Request,
@@ -71,7 +53,7 @@ export async function reGenerateClientSecret(
);
}
const { olmId, secret } = parsedBody.data;
const { secret, disconnect } = parsedBody.data;
const parsedParams = reGenerateSecretParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
@@ -85,11 +67,7 @@ export async function reGenerateClientSecret(
const { clientId } = parsedParams.data;
let secretHash = undefined;
if (secret) {
secretHash = await hashPassword(secret);
}
const secretHash = await hashPassword(secret);
// Fetch the client to make sure it exists and the user has access to it
const [client] = await db
@@ -107,24 +85,59 @@ export async function reGenerateClientSecret(
);
}
const [existingOlm] = await db
const existingOlms = await db
.select()
.from(olms)
.where(eq(olms.clientId, clientId))
.limit(1);
.where(eq(olms.clientId, clientId));
if (existingOlm && olmId && secretHash) {
await db
.update(olms)
.set({
olmId,
secretHash
})
.where(eq(olms.clientId, clientId));
if (existingOlms.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`No OLM found for client ID ${clientId}`
)
);
}
if (existingOlms.length > 1) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Multiple OLM entries found for client ID ${clientId}`
)
);
}
await db
.update(olms)
.set({
secretHash
})
.where(eq(olms.olmId, existingOlms[0].olmId));
// Only disconnect if explicitly requested
if (disconnect) {
const payload = {
type: `olm/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingOlms[0].olmId, payload).catch((error) => {
logger.error(
"Failed to send termination message to olm:",
error
);
});
disconnectClient(existingOlms[0].olmId).catch((error) => {
logger.error("Failed to disconnect olm after re-key:", error);
});
}
return response(res, {
data: existingOlm,
data: {
olmId: existingOlms[0].olmId
},
success: true,
error: false,
message: "Credentials regenerated successfully",

View File

@@ -12,7 +12,14 @@
*/
import { NextFunction, Request, Response } from "express";
import { db, exitNodes, exitNodeOrgs, ExitNode, ExitNodeOrg } from "@server/db";
import {
db,
exitNodes,
exitNodeOrgs,
ExitNode,
ExitNodeOrg,
RemoteExitNode
} from "@server/db";
import HttpCode from "@server/types/HttpCode";
import { z } from "zod";
import { remoteExitNodes } from "@server/db";
@@ -22,35 +29,17 @@ import { fromError } from "zod-validation-error";
import { hashPassword } from "@server/auth/password";
import logger from "@server/logger";
import { and, eq } from "drizzle-orm";
import { UpdateRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
import { OpenAPITags, registry } from "@server/openApi";
import { disconnectClient, sendToClient } from "#private/routers/ws";
export const paramsSchema = z.object({
orgId: z.string()
});
const bodySchema = z.strictObject({
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
registry.registerPath({
method: "post",
path: "/re-key/{orgId}/regenerate-secret",
description: "Regenerate a exit node credentials by its org ID.",
tags: [OpenAPITags.Org],
request: {
params: paramsSchema,
body: {
content: {
"application/json": {
schema: bodySchema
}
}
}
},
responses: {}
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48),
disconnect: z.boolean().optional().default(true)
});
export async function reGenerateExitNodeSecret(
@@ -79,13 +68,7 @@ export async function reGenerateExitNodeSecret(
);
}
const { remoteExitNodeId, secret } = parsedBody.data;
if (req.user && !req.userOrgRoleId) {
return next(
createHttpError(HttpCode.FORBIDDEN, "User does not have a role")
);
}
const { remoteExitNodeId, secret, disconnect } = parsedBody.data;
const [existingRemoteExitNode] = await db
.select()
@@ -94,7 +77,10 @@ export async function reGenerateExitNodeSecret(
if (!existingRemoteExitNode) {
return next(
createHttpError(HttpCode.NOT_FOUND, "Remote Exit Node does not exist")
createHttpError(
HttpCode.NOT_FOUND,
"Remote Exit Node does not exist"
)
);
}
@@ -105,15 +91,39 @@ export async function reGenerateExitNodeSecret(
.set({ secretHash })
.where(eq(remoteExitNodes.remoteExitNodeId, remoteExitNodeId));
return response<UpdateRemoteExitNodeResponse>(res, {
data: {
remoteExitNodeId,
secret,
},
// Only disconnect if explicitly requested
if (disconnect) {
const payload = {
type: `remoteExitNode/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(
existingRemoteExitNode.remoteExitNodeId,
payload
).catch((error) => {
logger.error(
"Failed to send termination message to remote exit node:",
error
);
});
disconnectClient(existingRemoteExitNode.remoteExitNodeId).catch(
(error) => {
logger.error(
"Failed to disconnect remote exit node after re-key:",
error
);
}
);
}
return response(res, {
data: null,
success: true,
error: false,
message: "Remote Exit Node secret updated successfully",
status: HttpCode.OK,
status: HttpCode.OK
});
} catch (e) {
logger.error("Failed to update remoteExitNode", e);

View File

@@ -13,7 +13,7 @@
import { Request, Response, NextFunction } from "express";
import { z } from "zod";
import { db, newts, sites } from "@server/db";
import { db, Newt, newts, sites } from "@server/db";
import { eq } from "drizzle-orm";
import response from "@server/lib/response";
import HttpCode from "@server/types/HttpCode";
@@ -22,38 +22,19 @@ import logger from "@server/logger";
import { fromError } from "zod-validation-error";
import { OpenAPITags, registry } from "@server/openApi";
import { hashPassword } from "@server/auth/password";
import { addPeer } from "@server/routers/gerbil/peers";
import { addPeer, deletePeer } from "@server/routers/gerbil/peers";
import { getAllowedIps } from "@server/routers/target/helpers";
import { disconnectClient, sendToClient } from "#private/routers/ws";
const updateSiteParamsSchema = z.strictObject({
siteId: z.string().transform(Number).pipe(z.int().positive())
});
siteId: z.string().transform(Number).pipe(z.int().positive())
});
const updateSiteBodySchema = z.strictObject({
type: z.enum(["newt", "wireguard"]),
newtId: z.string().min(1).max(255).optional(),
newtSecret: z.string().min(1).max(255).optional(),
exitNodeId: z.int().positive().optional(),
pubKey: z.string().optional(),
subnet: z.string().optional(),
});
registry.registerPath({
method: "post",
path: "/re-key/{siteId}/regenerate-site-secret",
description: "Regenerate a site's Newt or WireGuard credentials by its site ID.",
tags: [OpenAPITags.Site],
request: {
params: updateSiteParamsSchema,
body: {
content: {
"application/json": {
schema: updateSiteBodySchema,
},
},
},
},
responses: {},
type: z.enum(["newt", "wireguard"]),
secret: z.string().min(1).max(255).optional(),
pubKey: z.string().optional(),
disconnect: z.boolean().optional().default(true)
});
export async function reGenerateSiteSecret(
@@ -65,74 +46,149 @@ export async function reGenerateSiteSecret(
const parsedParams = updateSiteParamsSchema.safeParse(req.params);
if (!parsedParams.success) {
return next(
createHttpError(HttpCode.BAD_REQUEST, fromError(parsedParams.error).toString())
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedParams.error).toString()
)
);
}
const parsedBody = updateSiteBodySchema.safeParse(req.body);
if (!parsedBody.success) {
return next(
createHttpError(HttpCode.BAD_REQUEST, fromError(parsedBody.error).toString())
createHttpError(
HttpCode.BAD_REQUEST,
fromError(parsedBody.error).toString()
)
);
}
const { siteId } = parsedParams.data;
const { type, exitNodeId, pubKey, subnet, newtId, newtSecret } = parsedBody.data;
let updatedSite = undefined;
const { type, pubKey, secret, disconnect } = parsedBody.data;
let existingNewt: Newt | null = null;
if (type === "newt") {
if (!newtSecret) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "newtSecret is required for newt sites")
);
}
const secretHash = await hashPassword(newtSecret);
updatedSite = await db
.update(newts)
.set({
newtId,
secretHash,
})
.where(eq(newts.siteId, siteId))
.returning();
logger.info(`Regenerated Newt credentials for site ${siteId}`);
} else if (type === "wireguard") {
if (!pubKey) {
return next(
createHttpError(HttpCode.BAD_REQUEST, "Public key is required for wireguard sites")
);
}
if (!exitNodeId) {
if (!secret) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Exit node ID is required for wireguard sites"
"newtSecret is required for newt sites"
)
);
}
const secretHash = await hashPassword(secret);
// get the newt to verify it exists
const existingNewts = await db
.select()
.from(newts)
.where(eq(newts.siteId, siteId));
if (existingNewts.length === 0) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`No Newt found for site ID ${siteId}`
)
);
}
if (existingNewts.length > 1) {
return next(
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
`Multiple Newts found for site ID ${siteId}`
)
);
}
existingNewt = existingNewts[0];
// update the secret on the existing newt
await db
.update(newts)
.set({
secretHash
})
.where(eq(newts.newtId, existingNewts[0].newtId));
// Only disconnect if explicitly requested
if (disconnect) {
const payload = {
type: `newt/wg/terminate`,
data: {}
};
// Don't await this to prevent blocking the response
sendToClient(existingNewts[0].newtId, payload).catch(
(error) => {
logger.error(
"Failed to send termination message to newt:",
error
);
}
);
disconnectClient(existingNewts[0].newtId).catch((error) => {
logger.error(
"Failed to disconnect newt after re-key:",
error
);
});
}
logger.info(`Regenerated Newt credentials for site ${siteId}`);
} else if (type === "wireguard") {
if (!pubKey) {
return next(
createHttpError(
HttpCode.BAD_REQUEST,
"Public key is required for wireguard sites"
)
);
}
try {
updatedSite = await db.transaction(async (tx) => {
await addPeer(exitNodeId, {
const [site] = await db
.select()
.from(sites)
.where(eq(sites.siteId, siteId))
.limit(1);
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found`
)
);
}
await db
.update(sites)
.set({ pubKey })
.where(eq(sites.siteId, siteId));
if (!site) {
return next(
createHttpError(
HttpCode.NOT_FOUND,
`Site with ID ${siteId} not found`
)
);
}
if (site.exitNodeId && site.subnet) {
await deletePeer(site.exitNodeId, site.pubKey!); // the old pubkey
await addPeer(site.exitNodeId, {
publicKey: pubKey,
allowedIps: subnet ? [subnet] : [],
allowedIps: await getAllowedIps(site.siteId)
});
const result = await tx
.update(sites)
.set({ pubKey })
.where(eq(sites.siteId, siteId))
.returning();
}
return result;
});
logger.info(`Regenerated WireGuard credentials for site ${siteId}`);
logger.info(
`Regenerated WireGuard credentials for site ${siteId}`
);
} catch (err) {
logger.error(
`Transaction failed while regenerating WireGuard secret for site ${siteId}`,
@@ -148,17 +204,21 @@ export async function reGenerateSiteSecret(
}
return response(res, {
data: updatedSite,
data: {
newtId: existingNewt ? existingNewt.newtId : undefined
},
success: true,
error: false,
message: "Credentials regenerated successfully",
status: HttpCode.OK,
status: HttpCode.OK
});
} catch (error) {
logger.error("Unexpected error in reGenerateSiteSecret", error);
return next(
createHttpError(HttpCode.INTERNAL_SERVER_ERROR, "An unexpected error occurred")
createHttpError(
HttpCode.INTERNAL_SERVER_ERROR,
"An unexpected error occurred"
)
);
}
}

View File

@@ -36,9 +36,9 @@ export const paramsSchema = z.object({
});
const bodySchema = z.strictObject({
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
remoteExitNodeId: z.string().length(15),
secret: z.string().length(48)
});
export type CreateRemoteExitNodeBody = z.infer<typeof bodySchema>;

View File

@@ -25,9 +25,9 @@ import { usageService } from "@server/lib/billing/usageService";
import { FeatureId } from "@server/lib/billing";
const paramsSchema = z.strictObject({
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
export async function deleteRemoteExitNode(
req: Request,

View File

@@ -24,9 +24,9 @@ import { fromError } from "zod-validation-error";
import { GetRemoteExitNodeResponse } from "@server/routers/remoteExitNode/types";
const getRemoteExitNodeSchema = z.strictObject({
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
orgId: z.string().min(1),
remoteExitNodeId: z.string().min(1)
});
async function query(remoteExitNodeId: string) {
const [remoteExitNode] = await db

View File

@@ -55,7 +55,8 @@ export async function getRemoteExitNodeToken(
try {
if (token) {
const { session, remoteExitNode } = await validateRemoteExitNodeSessionToken(token);
const { session, remoteExitNode } =
await validateRemoteExitNodeSessionToken(token);
if (session) {
if (config.getRawConfig().app.log_failed_attempts) {
logger.info(
@@ -103,7 +104,10 @@ export async function getRemoteExitNodeToken(
}
const resToken = generateSessionToken();
await createRemoteExitNodeSession(resToken, existingRemoteExitNode.remoteExitNodeId);
await createRemoteExitNodeSession(
resToken,
existingRemoteExitNode.remoteExitNodeId
);
// logger.debug(`Created RemoteExitNode token response: ${JSON.stringify(resToken)}`);

View File

@@ -33,7 +33,9 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
offlineCheckerInterval = setInterval(async () => {
try {
const twoMinutesAgo = Math.floor((Date.now() - OFFLINE_THRESHOLD_MS) / 1000);
const twoMinutesAgo = Math.floor(
(Date.now() - OFFLINE_THRESHOLD_MS) / 1000
);
// Find clients that haven't pinged in the last 2 minutes and mark them as offline
const newlyOfflineNodes = await db
@@ -48,11 +50,13 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
isNull(exitNodes.lastPing)
)
)
).returning();
)
.returning();
// Update the sites to offline if they have not pinged either
const exitNodeIds = newlyOfflineNodes.map(node => node.exitNodeId);
const exitNodeIds = newlyOfflineNodes.map(
(node) => node.exitNodeId
);
const sitesOnNode = await db
.select()
@@ -63,10 +67,10 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
inArray(sites.exitNodeId, exitNodeIds)
)
);
// loop through the sites and process their lastBandwidthUpdate as an iso string and if its more than 1 minute old then mark the site offline
for (const site of sitesOnNode) {
if (!site.lastBandwidthUpdate) {
if (!site.lastBandwidthUpdate) {
continue;
}
const lastBandwidthUpdate = new Date(site.lastBandwidthUpdate);
@@ -77,13 +81,12 @@ export const startRemoteExitNodeOfflineChecker = (): void => {
.where(eq(sites.siteId, site.siteId));
}
}
} catch (error) {
logger.error("Error in offline checker interval", { error });
}
}, OFFLINE_CHECK_INTERVAL);
logger.info("Started offline checker interval");
logger.debug("Started offline checker interval");
};
/**
@@ -100,7 +103,9 @@ export const stopRemoteExitNodeOfflineChecker = (): void => {
/**
* Handles ping messages from clients and responds with pong
*/
export const handleRemoteExitNodePingMessage: MessageHandler = async (context) => {
export const handleRemoteExitNodePingMessage: MessageHandler = async (
context
) => {
const { message, client: c, sendToClient } = context;
const remoteExitNode = c as RemoteExitNode;
@@ -120,7 +125,7 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
.update(exitNodes)
.set({
lastPing: Math.floor(Date.now() / 1000),
online: true,
online: true
})
.where(eq(exitNodes.exitNodeId, remoteExitNode.exitNodeId));
} catch (error) {
@@ -131,10 +136,10 @@ export const handleRemoteExitNodePingMessage: MessageHandler = async (context) =
message: {
type: "pong",
data: {
timestamp: new Date().toISOString(),
timestamp: new Date().toISOString()
}
},
broadcast: false,
excludeSender: false
};
};
};

View File

@@ -29,7 +29,8 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
return;
}
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } = message.data;
const { remoteExitNodeVersion, remoteExitNodeSecondaryVersion } =
message.data;
if (!remoteExitNodeVersion) {
logger.warn("Remote exit node version not found");
@@ -39,7 +40,10 @@ export const handleRemoteExitNodeRegisterMessage: MessageHandler = async (
// update the version
await db
.update(remoteExitNodes)
.set({ version: remoteExitNodeVersion, secondaryVersion: remoteExitNodeSecondaryVersion })
.set({
version: remoteExitNodeVersion,
secondaryVersion: remoteExitNodeSecondaryVersion
})
.where(
eq(
remoteExitNodes.remoteExitNodeId,

View File

@@ -24,8 +24,8 @@ import { fromError } from "zod-validation-error";
import { ListRemoteExitNodesResponse } from "@server/routers/remoteExitNode/types";
const listRemoteExitNodesParamsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
const listRemoteExitNodesSchema = z.object({
limit: z

View File

@@ -22,8 +22,8 @@ import { z } from "zod";
import { PickRemoteExitNodeDefaultsResponse } from "@server/routers/remoteExitNode/types";
const paramsSchema = z.strictObject({
orgId: z.string()
});
orgId: z.string()
});
export async function pickRemoteExitNodeDefaults(
req: Request,

View File

@@ -38,7 +38,9 @@ export async function quickStartRemoteExitNode(
next: NextFunction
): Promise<any> {
try {
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(req.body);
const parsedBody = quickStartRemoteExitNodeBodySchema.safeParse(
req.body
);
if (!parsedBody.success) {
return next(
createHttpError(

View File

@@ -11,4 +11,4 @@
* This file is not licensed under the AGPLv3.
*/
export * from "./ws";
export * from "./ws";

View File

@@ -23,4 +23,4 @@ export const messageHandlers: Record<string, MessageHandler> = {
"remoteExitNode/ping": handleRemoteExitNodePingMessage
};
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes
startRemoteExitNodeOfflineChecker(); // this is to handle the offline check for remote exit nodes

View File

@@ -37,7 +37,14 @@ import { validateRemoteExitNodeSessionToken } from "#private/auth/sessions/remot
import { rateLimitService } from "#private/lib/rateLimit";
import { messageHandlers } from "@server/routers/ws/messageHandlers";
import { messageHandlers as privateMessageHandlers } from "#private/routers/ws/messageHandlers";
import { AuthenticatedWebSocket, ClientType, WSMessage, TokenPayload, WebSocketRequest, RedisMessage } from "@server/routers/ws";
import {
AuthenticatedWebSocket,
ClientType,
WSMessage,
TokenPayload,
WebSocketRequest,
RedisMessage
} from "@server/routers/ws";
import { validateSessionToken } from "@server/auth/sessions/app";
// Merge public and private message handlers
@@ -55,9 +62,9 @@ const processMessage = async (
try {
const message: WSMessage = JSON.parse(data.toString());
logger.debug(
`Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
);
// logger.debug(
// `Processing message from ${clientType.toUpperCase()} ID: ${clientId}, type: ${message.type}`
// );
if (!message.type || typeof message.type !== "string") {
throw new Error("Invalid message format: missing or invalid type");
@@ -216,7 +223,7 @@ const initializeRedisSubscription = async (): Promise<void> => {
// Each node is responsible for restoring its own connection state to Redis
// This approach is more efficient than cross-node coordination because:
// 1. Each node knows its own connections (source of truth)
// 2. No network overhead from broadcasting state between nodes
// 2. No network overhead from broadcasting state between nodes
// 3. No race conditions from simultaneous updates
// 4. Redis becomes eventually consistent as each node restores independently
// 5. Simpler logic with better fault tolerance
@@ -233,8 +240,10 @@ const recoverConnectionState = async (): Promise<void> => {
// Each node simply restores its own local connections to Redis
// This is the source of truth - no need for cross-node coordination
await restoreLocalConnectionsToRedis();
logger.info("Redis connection state recovery completed - restored local state");
logger.info(
"Redis connection state recovery completed - restored local state"
);
} catch (error) {
logger.error("Error during Redis recovery:", error);
} finally {
@@ -251,8 +260,10 @@ const restoreLocalConnectionsToRedis = async (): Promise<void> => {
try {
// Restore all current local connections to Redis
for (const [clientId, clients] of connectedClients.entries()) {
const validClients = clients.filter(client => client.readyState === WebSocket.OPEN);
const validClients = clients.filter(
(client) => client.readyState === WebSocket.OPEN
);
if (validClients.length > 0) {
// Add this node to the client's connection list
await redisManager.sadd(getConnectionsKey(clientId), NODE_ID);
@@ -303,7 +314,10 @@ const addClient = async (
Date.now().toString()
);
} catch (error) {
logger.error("Failed to add client to Redis tracking (connection still functional locally):", error);
logger.error(
"Failed to add client to Redis tracking (connection still functional locally):",
error
);
}
}
@@ -326,9 +340,14 @@ const removeClient = async (
if (redisManager.isRedisEnabled()) {
try {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(getNodeConnectionsKey(NODE_ID, clientId));
await redisManager.del(
getNodeConnectionsKey(NODE_ID, clientId)
);
} catch (error) {
logger.error("Failed to remove client from Redis tracking (cleanup will occur on recovery):", error);
logger.error(
"Failed to remove client from Redis tracking (cleanup will occur on recovery):",
error
);
}
}
@@ -345,7 +364,10 @@ const removeClient = async (
ws.connectionId
);
} catch (error) {
logger.error("Failed to remove specific connection from Redis tracking:", error);
logger.error(
"Failed to remove specific connection from Redis tracking:",
error
);
}
}
@@ -372,7 +394,9 @@ const sendToClientLocal = async (
}
});
logger.debug(`sendToClient: Message type ${message.type} sent to clientId ${clientId}`);
logger.debug(
`sendToClient: Message type ${message.type} sent to clientId ${clientId}`
);
return true;
};
@@ -411,14 +435,22 @@ const sendToClient = async (
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error("Failed to send message via Redis, message may be lost:", error);
logger.error(
"Failed to send message via Redis, message may be lost:",
error
);
// Continue execution - local delivery already attempted
}
} else if (!localSent && !redisManager.isRedisEnabled()) {
// Redis is disabled or unavailable - log that we couldn't deliver to remote nodes
logger.debug(`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`);
logger.debug(
`Could not deliver message to ${clientId} - not connected locally and Redis unavailable`
);
}
return localSent;
@@ -441,13 +473,21 @@ const broadcastToAllExcept = async (
fromNodeId: NODE_ID
};
await redisManager.publish(REDIS_CHANNEL, JSON.stringify(redisMessage));
await redisManager.publish(
REDIS_CHANNEL,
JSON.stringify(redisMessage)
);
} catch (error) {
logger.error("Failed to broadcast message via Redis, remote nodes may not receive it:", error);
logger.error(
"Failed to broadcast message via Redis, remote nodes may not receive it:",
error
);
// Continue execution - local broadcast already completed
}
} else {
logger.debug("Redis unavailable - broadcast limited to local node only");
logger.debug(
"Redis unavailable - broadcast limited to local node only"
);
}
};
@@ -512,8 +552,10 @@ const verifyToken = async (
return null;
}
if (olm.userId) { // this is a user device and we need to check the user token
const { session: userSession, user } = await validateSessionToken(userToken);
if (olm.userId) {
// this is a user device and we need to check the user token
const { session: userSession, user } =
await validateSessionToken(userToken);
if (!userSession || !user) {
return null;
}
@@ -668,7 +710,7 @@ const handleWSUpgrade = (server: HttpServer): void => {
url.searchParams.get("token") ||
request.headers["sec-websocket-protocol"] ||
"";
const userToken = url.searchParams.get('userToken') || '';
const userToken = url.searchParams.get("userToken") || "";
let clientType = url.searchParams.get(
"clientType"
) as ClientType;
@@ -690,7 +732,11 @@ const handleWSUpgrade = (server: HttpServer): void => {
return;
}
const tokenPayload = await verifyToken(token, clientType, userToken);
const tokenPayload = await verifyToken(
token,
clientType,
userToken
);
if (!tokenPayload) {
logger.debug(
"Unauthorized connection attempt: invalid token..."
@@ -724,50 +770,68 @@ const handleWSUpgrade = (server: HttpServer): void => {
// Add periodic connection state sync to handle Redis disconnections/reconnections
const startPeriodicStateSync = (): void => {
// Lightweight sync every 5 minutes - just restore our own state
setInterval(async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error("Error during periodic connection state sync:", error);
setInterval(
async () => {
if (redisManager.isRedisEnabled() && !isRedisRecoveryInProgress) {
try {
await restoreLocalConnectionsToRedis();
logger.debug("Periodic connection state sync completed");
} catch (error) {
logger.error(
"Error during periodic connection state sync:",
error
);
}
}
}
}, 5 * 60 * 1000); // 5 minutes
},
5 * 60 * 1000
); // 5 minutes
// Cleanup stale connections every 15 minutes
setInterval(async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error("Error during periodic connection cleanup:", error);
setInterval(
async () => {
if (redisManager.isRedisEnabled()) {
try {
await cleanupStaleConnections();
logger.debug("Periodic connection cleanup completed");
} catch (error) {
logger.error(
"Error during periodic connection cleanup:",
error
);
}
}
}
}, 15 * 60 * 1000); // 15 minutes
},
15 * 60 * 1000
); // 15 minutes
};
const cleanupStaleConnections = async (): Promise<void> => {
if (!redisManager.isRedisEnabled()) return;
try {
const nodeKeys = await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`) || [];
const nodeKeys =
(await redisManager.getClient()?.keys(`ws:node:${NODE_ID}:*`)) ||
[];
for (const nodeKey of nodeKeys) {
const connections = await redisManager.hgetall(nodeKey);
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, '');
const clientId = nodeKey.replace(`ws:node:${NODE_ID}:`, "");
const localClients = connectedClients.get(clientId) || [];
const localConnectionIds = localClients
.filter(client => client.readyState === WebSocket.OPEN)
.map(client => client.connectionId)
.filter((client) => client.readyState === WebSocket.OPEN)
.map((client) => client.connectionId)
.filter(Boolean);
// Remove Redis entries for connections that no longer exist locally
for (const [connectionId, timestamp] of Object.entries(connections)) {
for (const [connectionId, timestamp] of Object.entries(
connections
)) {
if (!localConnectionIds.includes(connectionId)) {
await redisManager.hdel(nodeKey, connectionId);
logger.debug(`Cleaned up stale connection: ${connectionId} for client: ${clientId}`);
logger.debug(
`Cleaned up stale connection: ${connectionId} for client: ${clientId}`
);
}
}
@@ -776,7 +840,9 @@ const cleanupStaleConnections = async (): Promise<void> => {
if (Object.keys(remainingConnections).length === 0) {
await redisManager.srem(getConnectionsKey(clientId), NODE_ID);
await redisManager.del(nodeKey);
logger.debug(`Cleaned up empty connection tracking for client: ${clientId}`);
logger.debug(
`Cleaned up empty connection tracking for client: ${clientId}`
);
}
}
} catch (error) {
@@ -789,38 +855,38 @@ if (redisManager.isRedisEnabled()) {
initializeRedisSubscription().catch((error) => {
logger.error("Failed to initialize Redis subscription:", error);
});
// Register recovery callback with Redis manager
// When Redis reconnects, each node simply restores its own local state
redisManager.onReconnection(async () => {
logger.info("Redis reconnected, starting WebSocket state recovery...");
await recoverConnectionState();
});
// Start periodic state synchronization
startPeriodicStateSync();
logger.info(
`WebSocket handler initialized with Redis support - Node ID: ${NODE_ID}`
);
} else {
logger.debug(
"WebSocket handler initialized in local mode"
);
logger.debug("WebSocket handler initialized in local mode");
}
// Disconnect a specific client and force them to reconnect
const disconnectClient = async (clientId: string): Promise<boolean> => {
const mapKey = getClientMapKey(clientId);
const clients = connectedClients.get(mapKey);
if (!clients || clients.length === 0) {
logger.debug(`No connections found for client ID: ${clientId}`);
return false;
}
logger.info(`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`);
logger.info(
`Disconnecting client ID: ${clientId} (${clients.length} connection(s))`
);
// Close all connections for this client
clients.forEach((client) => {
if (client.readyState === WebSocket.OPEN) {